]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add more generic custom op, remove deprecated custom ops (ggml/1183)
authorDiego Devesa <redacted>
Wed, 9 Apr 2025 10:31:34 +0000 (12:31 +0200)
committerGeorgi Gerganov <redacted>
Thu, 10 Apr 2025 21:17:47 +0000 (00:17 +0300)
* ggml : add more generic ggml_custom op

* ggml : remove deprecated custom ops

ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/ops.h
ggml/src/ggml-impl.h
ggml/src/ggml.c

index 452c967b0a637e3da4905605d3657bab8525ba82..a5447ecdf686a93e51245a0dda6db96a4bd52588 100644 (file)
@@ -507,17 +507,12 @@ extern "C" {
 
         GGML_OP_UNARY,
 
-        GGML_OP_MAP_UNARY,
-        GGML_OP_MAP_BINARY,
-
-        GGML_OP_MAP_CUSTOM1_F32,
-        GGML_OP_MAP_CUSTOM2_F32,
-        GGML_OP_MAP_CUSTOM3_F32,
-
         GGML_OP_MAP_CUSTOM1,
         GGML_OP_MAP_CUSTOM2,
         GGML_OP_MAP_CUSTOM3,
 
+        GGML_OP_CUSTOM,
+
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
@@ -1916,83 +1911,6 @@ extern "C" {
 
     // custom operators
 
-    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
-    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3_inplace instead");
-
-    // custom operators v2
-
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
     typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
     typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2048,6 +1966,30 @@ extern "C" {
             int                     n_tasks,
             void                  * userdata);
 
+    typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_4d(
+            struct ggml_context * ctx,
+            enum ggml_type        type,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
     // loss function
 
     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
index 34618c27aa4753f98c789765e81272b58024cadf..50400328738efed580c5fc21564475846b905196 100644 (file)
@@ -2027,41 +2027,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_MAP_UNARY:
-            {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1_F32:
-            {
-                ggml_custom1_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2_F32:
-            {
-                ggml_custom2_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                ggml_custom3_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
-            }
-            break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2077,6 +2042,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_custom3(params, tensor);
             }
             break;
+        case GGML_OP_CUSTOM:
+            {
+                ggml_compute_forward_custom(params, tensor);
+            }
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 ggml_compute_forward_cross_entropy_loss(params, tensor);
@@ -2328,11 +2298,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_MAP_UNARY:
-        case GGML_OP_MAP_BINARY:
-        case GGML_OP_MAP_CUSTOM1_F32:
-        case GGML_OP_MAP_CUSTOM2_F32:
-        case GGML_OP_MAP_CUSTOM3_F32:
             {
                 n_tasks = 1;
             } break;
@@ -2366,6 +2331,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
+        case GGML_OP_CUSTOM:
+            {
+                struct ggml_custom_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
index f63656be54f5c3c69b604a019c4d2dd7f5c09f62..36b98152e0887dc365122263cb3d90c67129477f 100644 (file)
@@ -8268,152 +8268,6 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
 
-// ggml_compute_forward_map_unary
-
-static void ggml_compute_forward_map_unary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_unary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_binary
-
-static void ggml_compute_forward_map_binary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(src1));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_binary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_binary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_custom1
-
-void ggml_compute_forward_map_custom1_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a);
-}
-
-// ggml_compute_forward_map_custom2
-
-void ggml_compute_forward_map_custom2_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b);
-}
-
-// ggml_compute_forward_map_custom3
-
-void ggml_compute_forward_map_custom3_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-    const ggml_tensor * c = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b, c);
-}
-
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(
@@ -8459,6 +8313,18 @@ void ggml_compute_forward_map_custom3(
     p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
 }
 
+// ggml_compute_forward_custom
+
+void ggml_compute_forward_custom(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    struct ggml_custom_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, params->ith, params->nth, p.userdata);
+}
+
 // ggml_compute_forward_cross_entropy_loss
 
 static void ggml_compute_forward_cross_entropy_loss_f32(
index d43fbc1fc472ac29f86652fbed03e159810b7728..410a372047a01032fc974ed4fb3fe2368a13d125 100644 (file)
@@ -96,29 +96,10 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_map_unary(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_unary_op_f32_t fun);
-void ggml_compute_forward_map_binary(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_binary_op_f32_t fun);
-void ggml_compute_forward_map_custom1_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_custom1_op_f32_t fun);
-void ggml_compute_forward_map_custom2_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_custom2_op_f32_t fun);
-void ggml_compute_forward_map_custom3_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_custom3_op_f32_t fun);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index caa6b9dba3f066e9da553a89119a1d0840628a40..13a3857cf5b61e84e035d728a6245c3ee0ca78fa 100644 (file)
@@ -140,8 +140,14 @@ struct ggml_map_custom2_op_params {
 
 struct ggml_map_custom3_op_params {
     ggml_custom3_op_t fun;
-    int n_tasks;
-    void * userdata;
+    int               n_tasks;
+    void            * userdata;
+};
+
+struct ggml_custom_op_params {
+    ggml_custom_op_t fun;
+    int              n_tasks;
+    void           * userdata;
 };
 
 // bitset
index 3e274d6ae39614b106270de27a4b158c95db072a..98a0f61642be5aa8c8db68873080383635021c57 100644 (file)
@@ -982,23 +982,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1081,23 +1076,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4842,179 +4832,6 @@ struct ggml_tensor * ggml_unary_inplace(
     return ggml_unary_impl(ctx, a, op, true);
 }
 
-// ggml_map_unary
-
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun,
-        bool                         inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
-
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun,
-        bool                          inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom1_f32
-
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_custom2_f32
-
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom3_f32
-
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // ggml_map_custom1
 
 static struct ggml_tensor * ggml_map_custom1_impl(
@@ -5033,7 +4850,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5078,7 +4895,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5127,7 +4944,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5159,6 +4976,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type        type,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(