]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : refactor unary ops (#405)
authorIvan Zdane <redacted>
Sun, 23 Jul 2023 19:44:13 +0000 (15:44 -0400)
committerGitHub <redacted>
Sun, 23 Jul 2023 19:44:13 +0000 (22:44 +0300)
* Add gitignore rule for temporary vim files

* ggml: refactor implementation of unary ops

* backends : adapt to ggml_unary_op

* ggml : fix assert number of ops

* ggml : hide ggml_set_unary_op from public API

---------

Co-authored-by: izdane <redacted>
Co-authored-by: Georgi Gerganov <redacted>
.gitignore
examples/mnist/main-mtl.m
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml.c

index c7a8f76b0b8161542d907518f2f9d0904bc8e3f6..467c19dafba4048a13c6180ac003876a772989dd 100644 (file)
@@ -29,3 +29,5 @@ zig-out/
 zig-cache/
 
 *.dot
+
+*.sw?
index cc4e6c25ab8c28ccf87b364211a06499abf482a9..bee8d4c902b8eaae0c9b968ac815c083297c497f 100644 (file)
@@ -340,22 +340,32 @@ int mnist_mtl_eval(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
-            case GGML_OP_RELU:
-                {
-                    if (encoder == nil) {
-                        encoder = [command_buffer computeCommandEncoder];
-                    }
-
-                    id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
-                    id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
-
-                    [encoder setComputePipelineState:ctx->pipeline_relu];
-                    [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];
-
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
-
-                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            case GGML_OP_UNARY:
+                switch (ggml_get_unary_op(gf->nodes[i])) {
+                    case GGML_UNARY_OP_RELU:
+                        {
+                            if (encoder == nil) {
+                                encoder = [command_buffer computeCommandEncoder];
+                            }
+
+                            id<MTLBuffer> id_src = mnist_mtl_get_buffer(ctx, gf->nodes[i]->src[0], &offs_src0);
+                            id<MTLBuffer> id_dst = mnist_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                            [encoder setComputePipelineState:ctx->pipeline_relu];
+                            [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];
+
+                            const int64_t n = ggml_nelements(gf->nodes[i]);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    default:
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s, unary op %d not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), (int) ggml_get_unary_op(gf->nodes[i]));
+                            GGML_ASSERT(false);
+                            return -1;
+                        }
+                        break;
                 } break;
             case GGML_OP_SOFT_MAX:
                 {
@@ -435,9 +445,11 @@ int mnist_mtl_eval(
                     [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
                 } break;
             default:
-                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
-                GGML_ASSERT(false);
-                return -1;
+                {
+                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                    GGML_ASSERT(false);
+                    return -1;
+                }
         }
     }
 
index def39e35544f2e0ccfaa20a0f47a238235886631..de44fba9e0961886ba2d35046e60dc7934f05961 100644 (file)
@@ -330,16 +330,6 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
-        GGML_OP_ABS,
-        GGML_OP_SGN,
-        GGML_OP_NEG,
-        GGML_OP_STEP,
-        GGML_OP_TANH,
-        GGML_OP_ELU,
-        GGML_OP_RELU,
-        GGML_OP_GELU,
-        GGML_OP_GELU_QUICK,
-        GGML_OP_SILU,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
@@ -378,6 +368,8 @@ extern "C" {
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
 
+        GGML_OP_UNARY,
+
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
 
@@ -391,6 +383,18 @@ extern "C" {
         GGML_OP_COUNT,
     };
 
+    enum ggml_unary_op {
+        GGML_UNARY_OP_ABS,
+        GGML_UNARY_OP_SGN,
+        GGML_UNARY_OP_NEG,
+        GGML_UNARY_OP_STEP,
+        GGML_UNARY_OP_TANH,
+        GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_QUICK,
+        GGML_UNARY_OP_SILU,
+    };
 
     // ggml object
     struct ggml_object {
@@ -535,6 +539,7 @@ extern "C" {
 
     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
+    GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
@@ -618,9 +623,11 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
-    GGML_API const char *         ggml_get_name(const struct ggml_tensor * tensor);
-    GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
-    GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...);
+    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
+    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
+    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
 
     //
     // operations on tensors with backpropagation
@@ -1285,6 +1292,16 @@ extern "C" {
     typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
     typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
 
+    GGML_API struct ggml_tensor * ggml_unary(
+            struct ggml_context * ctx,
+             struct ggml_tensor * a,
+             enum ggml_unary_op op);
+
+    GGML_API struct ggml_tensor * ggml_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op);
+
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context        * ctx,
             struct ggml_tensor         * a,
index 6fb55d838dfb3e4a7ec505f17a3adf334e085ffe..0ab06ec9a4b6d6ebceab943763143af6b75b417a 100644 (file)
@@ -3908,18 +3908,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
-        case GGML_OP_GELU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_gelu;
-            break;
-        case GGML_OP_SILU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_silu;
-            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_gelu;
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_silu;
+                    break;
+                default:
+                    return false;
+            } break;
         case GGML_OP_NORM:
             if (!any_on_device) {
                 return false;
index bf3f68fe45726ca93a10b9293200d11c717e7789..1fd6e857ffe6138dc2a5be702bebcd57eb5c6fc5 100644 (file)
@@ -519,48 +519,56 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
-                    case GGML_OP_SILU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
-                    case GGML_OP_RELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_relu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    case GGML_OP_UNARY:
+                        switch (ggml_get_unary_op(gf->nodes[i])) {
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoder];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_RELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoder];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoder];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            default:
+                                {
+                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    GGML_ASSERT(false);
+                                }
                         } break;
-                    case GGML_OP_GELU:
-                    {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_gelu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             if (encoder == nil) {
@@ -979,8 +987,10 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     default:
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                        GGML_ASSERT(false);
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_ASSERT(false);
+                        }
                 }
             }
 
index 90f32a57500b29c9a89d11a0147b29b210186bfb..960b8057709a987e8aa83395fbe51a5c597fe21f 100644 (file)
@@ -3760,16 +3760,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGMAX",
     "REPEAT",
     "REPEAT_BACK",
-    "ABS",
-    "SGN",
-    "NEG",
-    "STEP",
-    "TANH",
-    "ELU",
-    "RELU",
-    "GELU",
-    "GELU_QUICK",
-    "SILU",
     "SILU_BACK",
     "NORM",
     "RMS_NORM",
@@ -3808,6 +3798,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_PART",
     "WIN_UNPART",
 
+    "UNARY",
+
     "MAP_UNARY",
     "MAP_BINARY",
 
@@ -3819,7 +3811,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 59, "GGML_OP_COUNT != 59");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3840,16 +3832,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argmax(x)",
     "repeat(x)",
     "repeat_back(x)",
-    "abs(x)",
-    "sgn(x)",
-    "-x",
-    "step(x)",
-    "tanh(x)",
-    "elu(x)",
-    "relu(x)",
-    "gelu(x)",
-    "gelu_quick(x)",
-    "silu(x)",
     "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
@@ -3888,6 +3870,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_part(x)",
     "win_unpart(x)",
 
+    "unary(x)",
+
     "f(x)",
     "f(x,y)",
 
@@ -3899,7 +3883,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 59, "GGML_OP_COUNT != 59");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4155,6 +4139,10 @@ const char * ggml_op_name(enum ggml_op op) {
     return GGML_OP_NAME[op];
 }
 
+const char * ggml_op_symbol(enum ggml_op op) {
+    return GGML_OP_SYMBOL[op];
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
 }
@@ -4635,6 +4623,21 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     return result;
 }
 
+static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+    assert(params_size <= GGML_MAX_OP_PARAMS);
+    memcpy(tensor->op_params, params, params_size);
+}
+
+static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    return ((const int32_t *)(tensor->op_params))[i];
+}
+
+static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    ((int32_t *)(tensor->op_params))[i] = value;
+}
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
@@ -4966,6 +4969,16 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
     return (float *)(tensor->data);
 }
 
+enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_UNARY);
+    return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
+}
+
+static void ggml_set_unary_op(struct ggml_tensor * tensor, enum ggml_unary_op op) {
+    GGML_ASSERT(tensor->op = GGML_OP_UNARY);
+    ggml_set_op_params_i32(tensor, 0, (int32_t) op);
+}
+
 const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
@@ -4984,11 +4997,6 @@ struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char *
     return tensor;
 }
 
-static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
-    assert(params_size <= GGML_MAX_OP_PARAMS);
-    memcpy(tensor->op_params, params, params_size);
-}
-
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
@@ -5573,333 +5581,142 @@ struct ggml_tensor * ggml_repeat_back(
 
 // ggml_abs
 
-static struct ggml_tensor * ggml_abs_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_ABS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_abs(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_abs_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);
 }
 
 struct ggml_tensor * ggml_abs_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_abs_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
 }
 
-
 // ggml_sgn
 
-static struct ggml_tensor * ggml_sgn_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_SGN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_sgn(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_sgn_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SGN);
 }
 
 struct ggml_tensor * ggml_sgn_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_sgn_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);
 }
 
 // ggml_neg
 
-static struct ggml_tensor * ggml_neg_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_NEG;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_neg(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_neg_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);
 }
 
 struct ggml_tensor * ggml_neg_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_neg_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);
 }
 
 // ggml_step
 
-static struct ggml_tensor * ggml_step_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_STEP;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_step(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_step_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);
 }
 
 struct ggml_tensor * ggml_step_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_step_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);
 }
 
 // ggml_tanh
 
-static struct ggml_tensor * ggml_tanh_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_TANH;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_tanh(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_tanh_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);
 }
 
 struct ggml_tensor * ggml_tanh_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_tanh_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);
 }
 
 // ggml_elu
 
-static struct ggml_tensor * ggml_elu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_ELU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_elu(
     struct ggml_context * ctx,
     struct ggml_tensor  * a) {
-    return ggml_elu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);
 }
 
 struct ggml_tensor * ggml_elu_inplace(
     struct ggml_context * ctx,
     struct ggml_tensor  * a) {
-    return ggml_elu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);
 }
 
 // ggml_relu
 
-static struct ggml_tensor * ggml_relu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_RELU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_relu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_relu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);
 }
 
 struct ggml_tensor * ggml_relu_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_relu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
 // ggml_gelu
 
-static struct ggml_tensor * ggml_gelu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_GELU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);
 }
 
 struct ggml_tensor * ggml_gelu_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
 }
 
 // ggml_gelu_quick
 
-static struct ggml_tensor * ggml_gelu_quick_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_GELU_QUICK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_gelu_quick(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_quick_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);
 }
 
 struct ggml_tensor * ggml_gelu_quick_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_quick_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);
 }
 
 // ggml_silu
 
-static struct ggml_tensor * ggml_silu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_SILU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_silu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_silu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
 }
 
 struct ggml_tensor * ggml_silu_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_silu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
 }
 
 // ggml_silu_back
@@ -7377,6 +7194,44 @@ struct ggml_tensor * ggml_win_unpart(
     return result;
 }
 
+// gmml_unary
+
+static struct ggml_tensor * ggml_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        enum ggml_unary_op op,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_unary_op(result, op);
+
+    result->op   = GGML_OP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op) {
+    return ggml_unary_impl(ctx, a, op, false);
+}
+
+struct ggml_tensor * ggml_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op) {
+    return ggml_unary_impl(ctx, a, op, true);
+}
+
 // ggml_map_unary
 
 static struct ggml_tensor * ggml_map_unary_impl_f32(
@@ -10127,7 +9982,6 @@ static void ggml_compute_forward_silu(
     }
 }
 
-
 // ggml_compute_forward_silu_back
 
 static void ggml_compute_forward_silu_back_f32(
@@ -14172,6 +14026,62 @@ static void ggml_compute_forward_win_unpart(
     }
 }
 
+//gmml_compute_forward_unary
+
+static void ggml_compute_forward_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    const enum ggml_unary_op op = ggml_get_unary_op(dst);
+
+    switch (op) {
+        case GGML_UNARY_OP_ABS:
+            {
+                ggml_compute_forward_abs(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_SGN:
+            {
+                ggml_compute_forward_sgn(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_NEG:
+            {
+                ggml_compute_forward_neg(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_STEP:
+            {
+                ggml_compute_forward_step(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_TANH:
+            {
+                ggml_compute_forward_tanh(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_ELU:
+            {
+                ggml_compute_forward_elu(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_RELU:
+            {
+                ggml_compute_forward_relu(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_GELU:
+            {
+                ggml_compute_forward_gelu(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -14732,46 +14642,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
             } break;
-        case GGML_OP_ABS:
-            {
-                ggml_compute_forward_abs(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_SGN:
-            {
-                ggml_compute_forward_sgn(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_NEG:
-            {
-                ggml_compute_forward_neg(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_STEP:
-            {
-                ggml_compute_forward_step(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_TANH:
-            {
-                ggml_compute_forward_tanh(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_ELU:
-            {
-                ggml_compute_forward_elu(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_RELU:
-            {
-                ggml_compute_forward_relu(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_GELU:
-            {
-                ggml_compute_forward_gelu(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_GELU_QUICK:
-            {
-                ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_SILU:
-            {
-                ggml_compute_forward_silu(params, tensor->src[0], tensor);
-            } break;
         case GGML_OP_SILU_BACK:
             {
                 ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
@@ -14914,6 +14784,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_win_unpart(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_UNARY:
+            {
+                ggml_compute_forward_unary(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -15162,73 +15036,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             inplace);
                 }
             } break;
-        case GGML_OP_ABS:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_impl(ctx,
-                                src0->grad,
-                                ggml_mul(ctx,
-                                    ggml_sgn(ctx, src0),
-                                    tensor->grad),
-                                inplace);
-                }
-            } break;
-        case GGML_OP_SGN:
-            {
-                if (src0->grad) {
-                    // noop
-                }
-            } break;
-        case GGML_OP_NEG:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
-                }
-            } break;
-        case GGML_OP_STEP:
-            {
-                if (src0->grad) {
-                    // noop
-                }
-            } break;
-        case GGML_OP_TANH:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_ELU:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_RELU:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
-                            src0->grad,
-                            ggml_mul(ctx,
-                                ggml_step(ctx, src0),
-                                tensor->grad),
-                            inplace);
-                }
-            } break;
-        case GGML_OP_GELU:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_GELU_QUICK:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_SILU:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
-                            src0->grad,
-                            ggml_silu_back(ctx, src0, tensor->grad),
-                            inplace);
-                }
-            } break;
         case GGML_OP_SILU_BACK:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -15745,6 +15552,80 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
+        case GGML_OP_UNARY:
+            {
+                switch (ggml_get_unary_op(tensor)) {
+                    case GGML_UNARY_OP_ABS:
+                        {
+                            if (src0->grad) {
+                                src0->grad =
+                                    ggml_add_impl(ctx,
+                                            src0->grad,
+                                            ggml_mul(ctx,
+                                                ggml_sgn(ctx, src0),
+                                                tensor->grad),
+                                            inplace);
+                            }
+                        } break;
+                    case GGML_UNARY_OP_SGN:
+                        {
+                            if (src0->grad) {
+                                // noop
+                            }
+                        } break;
+                    case GGML_UNARY_OP_NEG:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+                            }
+                        } break;
+                    case GGML_UNARY_OP_STEP:
+                        {
+                            if (src0->grad) {
+                                // noop
+                            }
+                        } break;
+                    case GGML_UNARY_OP_TANH:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_ELU:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_RELU:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_add_impl(ctx,
+                                        src0->grad,
+                                        ggml_mul(ctx,
+                                            ggml_step(ctx, src0),
+                                            tensor->grad),
+                                        inplace);
+                            }
+                        } break;
+                    case GGML_UNARY_OP_GELU:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_GELU_QUICK:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_SILU:
+                        {
+                            // necessary for llama
+                            if (src0->grad) {
+                                src0->grad = ggml_add_impl(ctx,
+                                        src0->grad,
+                                        ggml_silu_back(ctx, src0, tensor->grad),
+                                        inplace);
+                            }
+                        } break;
+                    default:
+                        GGML_ASSERT(false);
+                }
+            } break;
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1:
@@ -16024,8 +15905,8 @@ static void clear_numa_thread_affinity(void) {
 #else
 // TODO: Windows etc.
 // (the linux implementation may also work on BSD, someone should test)
-void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads);  }
-void clear_numa_thread_affinity(void) {}
+static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads);  }
+static void clear_numa_thread_affinity(void) {}
 #endif
 
 struct ggml_compute_state_shared {
@@ -16237,21 +16118,34 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_ARGMAX:
             case GGML_OP_REPEAT:
             case GGML_OP_REPEAT_BACK:
-            case GGML_OP_ABS:
-            case GGML_OP_SGN:
-            case GGML_OP_NEG:
-            case GGML_OP_STEP:
-            case GGML_OP_TANH:
-            case GGML_OP_ELU:
-            case GGML_OP_RELU:
-                {
+            {
                     n_tasks = 1;
                 } break;
-            case GGML_OP_MUL:
-            case GGML_OP_GELU:
-            case GGML_OP_GELU_QUICK:
-            case GGML_OP_SILU:
+
+            case GGML_OP_UNARY:
+                {
+                    switch (ggml_get_unary_op(node)) {
+                        case GGML_UNARY_OP_ABS:
+                        case GGML_UNARY_OP_SGN:
+                        case GGML_UNARY_OP_NEG:
+                        case GGML_UNARY_OP_STEP:
+                        case GGML_UNARY_OP_TANH:
+                        case GGML_UNARY_OP_ELU:
+                        case GGML_UNARY_OP_RELU:
+                            {
+                                n_tasks = 1;
+                            } break;
+
+                        case GGML_UNARY_OP_GELU:
+                        case GGML_UNARY_OP_GELU_QUICK:
+                        case GGML_UNARY_OP_SILU:
+                            {
+                                n_tasks = n_threads;
+                            } break;
+                    }
+                } break;
             case GGML_OP_SILU_BACK:
+            case GGML_OP_MUL:
             case GGML_OP_NORM:
             case GGML_OP_RMS_NORM:
             case GGML_OP_RMS_NORM_BACK:
@@ -17122,7 +17016,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+                ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
                 (double) node->perf_time_us / 1000.0,
@@ -17136,7 +17030,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
-                GGML_OP_NAME[node->op]);
+                ggml_op_name(node->op));
     }
 
     for (int i = 0; i < GGML_OP_COUNT; i++) {
@@ -17144,7 +17038,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
             continue;
         }
 
-        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_NAME[i], (double) perf_total_per_op_us[i] / 1000.0);
+        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0);
     }
 
     GGML_PRINT("========================================\n");
@@ -17238,13 +17132,13 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         }
 
         if (node->n_dims == 2) {
-            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], GGML_OP_SYMBOL[node->op]);
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));
         } else {
-            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], GGML_OP_SYMBOL[node->op]);
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
         }
 
         if (node->grad) {
-            fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]);
+            fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(node->grad->op));
         } else {
             fprintf(fp, "\"; ]\n");
         }