]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add more generic custom op, remove deprecated custom ops (#1183)
authorDiego Devesa <redacted>
Wed, 9 Apr 2025 10:31:34 +0000 (12:31 +0200)
committerGitHub <redacted>
Wed, 9 Apr 2025 10:31:34 +0000 (12:31 +0200)
* ggml : add more generic ggml_custom op

* ggml : remove deprecated custom ops

include/ggml.h
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cpu/ops.h
src/ggml-impl.h
src/ggml.c
tests/test-customop.c

index 452c967b0a637e3da4905605d3657bab8525ba82..a5447ecdf686a93e51245a0dda6db96a4bd52588 100644 (file)
@@ -507,17 +507,12 @@ extern "C" {
 
         GGML_OP_UNARY,
 
-        GGML_OP_MAP_UNARY,
-        GGML_OP_MAP_BINARY,
-
-        GGML_OP_MAP_CUSTOM1_F32,
-        GGML_OP_MAP_CUSTOM2_F32,
-        GGML_OP_MAP_CUSTOM3_F32,
-
         GGML_OP_MAP_CUSTOM1,
         GGML_OP_MAP_CUSTOM2,
         GGML_OP_MAP_CUSTOM3,
 
+        GGML_OP_CUSTOM,
+
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
@@ -1916,83 +1911,6 @@ extern "C" {
 
     // custom operators
 
-    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
-    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
-            struct ggml_context        * ctx,
-            struct ggml_tensor         * a,
-                   ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-                   ggml_binary_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-                   ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-                   ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
-            struct ggml_context          * ctx,
-            struct ggml_tensor           * a,
-            struct ggml_tensor           * b,
-            struct ggml_tensor           * c,
-                   ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3_inplace instead");
-
-    // custom operators v2
-
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
     typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
     typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2048,6 +1966,30 @@ extern "C" {
             int                     n_tasks,
             void                  * userdata);
 
+    typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_4d(
+            struct ggml_context * ctx,
+            enum ggml_type        type,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
     // loss function
 
     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
index 34618c27aa4753f98c789765e81272b58024cadf..50400328738efed580c5fc21564475846b905196 100644 (file)
@@ -2027,41 +2027,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_MAP_UNARY:
-            {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1_F32:
-            {
-                ggml_custom1_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2_F32:
-            {
-                ggml_custom2_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                ggml_custom3_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
-            }
-            break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2077,6 +2042,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_custom3(params, tensor);
             }
             break;
+        case GGML_OP_CUSTOM:
+            {
+                ggml_compute_forward_custom(params, tensor);
+            }
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 ggml_compute_forward_cross_entropy_loss(params, tensor);
@@ -2328,11 +2298,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_MAP_UNARY:
-        case GGML_OP_MAP_BINARY:
-        case GGML_OP_MAP_CUSTOM1_F32:
-        case GGML_OP_MAP_CUSTOM2_F32:
-        case GGML_OP_MAP_CUSTOM3_F32:
             {
                 n_tasks = 1;
             } break;
@@ -2366,6 +2331,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
+        case GGML_OP_CUSTOM:
+            {
+                struct ggml_custom_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
index 7a8d5ac6fd9d0d36879f7e936b0b9ee806f69ba4..0aa9d09a501a7fb6eb46f65a9eb9bb691a37cfee 100644 (file)
@@ -8264,152 +8264,6 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
 
-// ggml_compute_forward_map_unary
-
-static void ggml_compute_forward_map_unary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_unary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_binary
-
-static void ggml_compute_forward_map_binary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(src1));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_binary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_binary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_custom1
-
-void ggml_compute_forward_map_custom1_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a);
-}
-
-// ggml_compute_forward_map_custom2
-
-void ggml_compute_forward_map_custom2_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b);
-}
-
-// ggml_compute_forward_map_custom3
-
-void ggml_compute_forward_map_custom3_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-    const ggml_tensor * c = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b, c);
-}
-
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(
@@ -8455,6 +8309,18 @@ void ggml_compute_forward_map_custom3(
     p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
 }
 
+// ggml_compute_forward_custom
+
+void ggml_compute_forward_custom(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    struct ggml_custom_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, params->ith, params->nth, p.userdata);
+}
+
 // ggml_compute_forward_cross_entropy_loss
 
 static void ggml_compute_forward_cross_entropy_loss_f32(
index d43fbc1fc472ac29f86652fbed03e159810b7728..410a372047a01032fc974ed4fb3fe2368a13d125 100644 (file)
@@ -96,29 +96,10 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_map_unary(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_unary_op_f32_t fun);
-void ggml_compute_forward_map_binary(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_binary_op_f32_t fun);
-void ggml_compute_forward_map_custom1_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_custom1_op_f32_t fun);
-void ggml_compute_forward_map_custom2_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_custom2_op_f32_t fun);
-void ggml_compute_forward_map_custom3_f32(
-    const struct ggml_compute_params * params,
-    struct ggml_tensor * dst,
-    const ggml_custom3_op_f32_t fun);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index 606175fb9241a6607f70004926771c32e736d905..7f94d776286d7e82a67359786e5a1670408e3693 100644 (file)
@@ -140,8 +140,14 @@ struct ggml_map_custom2_op_params {
 
 struct ggml_map_custom3_op_params {
     ggml_custom3_op_t fun;
-    int n_tasks;
-    void * userdata;
+    int               n_tasks;
+    void            * userdata;
+};
+
+struct ggml_custom_op_params {
+    ggml_custom_op_t fun;
+    int              n_tasks;
+    void           * userdata;
 };
 
 // bitset
index 3e274d6ae39614b106270de27a4b158c95db072a..98a0f61642be5aa8c8db68873080383635021c57 100644 (file)
@@ -982,23 +982,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1081,23 +1076,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4842,179 +4832,6 @@ struct ggml_tensor * ggml_unary_inplace(
     return ggml_unary_impl(ctx, a, op, true);
 }
 
-// ggml_map_unary
-
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun,
-        bool                         inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t   fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
-
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun,
-        bool                          inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t   fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom1_f32
-
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_f32_t   fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_custom2_f32
-
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_f32_t   fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom3_f32
-
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun,
-        bool                           inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_f32_t   fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // ggml_map_custom1
 
 static struct ggml_tensor * ggml_map_custom1_impl(
@@ -5033,7 +4850,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5078,7 +4895,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5127,7 +4944,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5159,6 +4976,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type        type,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
index c85cafe1b1e8990972506cab14fd4a4066919e3b..2428da0f549d10bfcd15285a5902e44297f47f27 100644 (file)
@@ -4,7 +4,6 @@
 #include <string.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <assert.h>
 
 #if defined(_WIN32)
 #include <windows.h>
@@ -36,8 +35,8 @@ atomic_int g_custom3_count = 0;
 
 void custom1(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata) {
     // check that the userdata is correct
-    assert(userdata == NULL);
-    assert(ggml_are_same_shape(dst, a));
+    GGML_ASSERT(userdata == NULL);
+    GGML_ASSERT(ggml_are_same_shape(dst, a));
 
     atomic_fetch_add(&g_custom1_count, 1);
 
@@ -45,8 +44,8 @@ void custom1(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, i
     float * dst_data = ggml_get_data_f32(dst);
 
     // this assumes that the tensors are contiguous
-    assert(ggml_is_contiguous(dst));
-    assert(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(a));
 
     // parallelize by elements
     const int ne = (int)ggml_nelements(dst);
@@ -61,10 +60,10 @@ void custom1(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, i
 
 void custom2(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {
     // check that the userdata is correct
-    assert(userdata == g_userdata);
-    assert(strcmp(userdata, "ggml") == 0);
-    assert(ggml_are_same_shape(dst, a));
-    assert(ggml_are_same_shape(dst, b));
+    GGML_ASSERT(userdata == g_userdata);
+    GGML_ASSERT(strcmp(userdata, "ggml") == 0);
+    GGML_ASSERT(ggml_are_same_shape(dst, a));
+    GGML_ASSERT(ggml_are_same_shape(dst, b));
 
     atomic_fetch_add(&g_custom2_count, 1);
 
@@ -84,9 +83,9 @@ void custom2(struct ggml_tensor * dst , const struct ggml_tensor * a, const stru
     const int nc = (int)dst->ne[0];
 
     // this assumes that the tensors are contiguous
-    assert(ggml_is_contiguous(dst));
-    assert(ggml_is_contiguous(a));
-    assert(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
 
     for (int ir = ir0; ir < ir1; ++ir) {
         for (int ic = 0; ic < nc; ++ic) {
@@ -98,11 +97,11 @@ void custom2(struct ggml_tensor * dst , const struct ggml_tensor * a, const stru
 
 void custom3(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata) {
     // check that the userdata is correct
-    assert(userdata == g_userdata);
-    assert(strcmp(userdata, "ggml") == 0);
-    assert(ggml_are_same_shape(dst, a));
-    assert(ggml_are_same_shape(dst, b));
-    assert(ggml_are_same_shape(dst, c));
+    GGML_ASSERT(userdata == g_userdata);
+    GGML_ASSERT(strcmp(userdata, "ggml") == 0);
+    GGML_ASSERT(ggml_are_same_shape(dst, a));
+    GGML_ASSERT(ggml_are_same_shape(dst, b));
+    GGML_ASSERT(ggml_are_same_shape(dst, c));
 
     atomic_fetch_add(&g_custom3_count, 1);
 
@@ -112,22 +111,61 @@ void custom3(struct ggml_tensor * dst , const struct ggml_tensor * a, const stru
     float * dst_data = ggml_get_data_f32(dst);
 
     // dont parallelize
-    assert(ith == 0);
+    GGML_ASSERT(ith == 0);
 
     // number of elements
     const int ne = (int)ggml_nelements(dst);
 
     // this assumes that the tensors are contiguous
-    assert(ggml_is_contiguous(dst));
-    assert(ggml_is_contiguous(a));
-    assert(ggml_is_contiguous(b));
-    assert(ggml_is_contiguous(c));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(b));
+    GGML_ASSERT(ggml_is_contiguous(c));
 
     for (int i = 0; i < ne; ++i) {
         dst_data[i] = a_data[i] + b_data[i] + c_data[i];
     }
 }
 
+void custom(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
+    struct ggml_tensor * src0 = dst->src[0];
+    struct ggml_tensor * src1 = dst->src[1];
+    struct ggml_tensor * src2 = dst->src[2];
+    struct ggml_tensor * src3 = dst->src[3];
+    struct ggml_tensor * src4 = dst->src[4];
+
+    int32_t * dst_data = (int32_t *) ggml_get_data(dst);
+    const float * src0_data = ggml_get_data_f32(src0);
+    const float * src1_data = ggml_get_data_f32(src1);
+    const float * src2_data = ggml_get_data_f32(src2);
+    const float * src3_data = ggml_get_data_f32(src3);
+    const float * src4_data = ggml_get_data_f32(src4);
+
+    // check that the userdata is correct
+    GGML_ASSERT(userdata == g_userdata);
+    GGML_ASSERT(strcmp(userdata, "ggml") == 0);
+
+    // check that the tensors are contiguous
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(src2));
+    GGML_ASSERT(ggml_is_contiguous(src3));
+    GGML_ASSERT(ggml_is_contiguous(src4));
+
+    // check that the shapes are the same
+    GGML_ASSERT(ggml_are_same_shape(dst, src0));
+    GGML_ASSERT(ggml_are_same_shape(dst, src1));
+    GGML_ASSERT(ggml_are_same_shape(dst, src2));
+    GGML_ASSERT(ggml_are_same_shape(dst, src3));
+    GGML_ASSERT(ggml_are_same_shape(dst, src4));
+
+
+    for (int i = ith; i < ggml_nelements(dst); i += nth) {
+        dst_data[i] = src0_data[i] + src1_data[i] * src2_data[i] - src3_data[i] * src4_data[i];
+    }
+}
+
 int main(int argc, const char** argv) {
 
     float buf1_f32[1024];
@@ -160,9 +198,9 @@ int main(int argc, const char** argv) {
         const float * output = ggml_get_data_f32(m1);
 
         for (int i = 0; i < ggml_nelements(m1); ++i) {
-            assert(output[i] == buf1_f32[i] * 2);
+            GGML_ASSERT(output[i] == buf1_f32[i] * 2);
         }
-        assert(g_custom1_count == 2);
+        GGML_ASSERT(g_custom1_count == 2);
 
         ggml_free(ctx);
     }
@@ -172,8 +210,8 @@ int main(int argc, const char** argv) {
     {
         struct ggml_context * ctx = make_ctx();
         struct ggml_tensor * t1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
-        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));
         struct ggml_tensor * t2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
+        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));
         memcpy(t2->data, buf2_f32, ggml_nbytes(t2));
 
         struct ggml_tensor * m2 = ggml_map_custom2(ctx, t1, t2, custom2, GGML_N_TASKS_MAX, g_userdata);
@@ -186,10 +224,10 @@ int main(int argc, const char** argv) {
         const float * output = ggml_get_data_f32(m2);
 
         for (int i = 0; i < ggml_nelements(m2); ++i) {
-            assert(output[i] == buf1_f32[i] + buf2_f32[i]);
+            GGML_ASSERT(output[i] == buf1_f32[i] + buf2_f32[i]);
         }
 
-        assert(g_custom2_count == 4);
+        GGML_ASSERT(g_custom2_count == 4);
 
         ggml_free(ctx);
     }
@@ -199,10 +237,11 @@ int main(int argc, const char** argv) {
     {
         struct ggml_context * ctx = make_ctx();
         struct ggml_tensor * t1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
-        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));
         struct ggml_tensor * t2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
-        memcpy(t2->data, buf2_f32, ggml_nbytes(t2));
         struct ggml_tensor * t3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
+
+        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));
+        memcpy(t2->data, buf2_f32, ggml_nbytes(t2));
         memcpy(t3->data, buf3_f32, ggml_nbytes(t3));
 
         struct ggml_tensor * m3 = ggml_map_custom3(ctx, t1, t2, t3, custom3, 1, g_userdata);
@@ -215,14 +254,47 @@ int main(int argc, const char** argv) {
         const float * output = ggml_get_data_f32(m3);
 
         for (int i = 0; i < ggml_nelements(m3); ++i) {
-            assert(output[i] == buf1_f32[i] + buf2_f32[i] + buf3_f32[i]);
+            GGML_ASSERT(output[i] == buf1_f32[i] + buf2_f32[i] + buf3_f32[i]);
         }
 
-        assert(g_custom3_count == 1);
+        GGML_ASSERT(g_custom3_count == 1);
 
         ggml_free(ctx);
     }
 
+    // custom
+    {
+        struct ggml_context * ctx = make_ctx();
+        struct ggml_tensor * t1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
+        struct ggml_tensor * t2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
+        struct ggml_tensor * t3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
+        struct ggml_tensor * t4 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
+        struct ggml_tensor * t5 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 2);
+        memcpy(t1->data, buf1_f32, ggml_nbytes(t1));
+        memcpy(t2->data, buf2_f32, ggml_nbytes(t2));
+        memcpy(t3->data, buf3_f32, ggml_nbytes(t3));
+        memcpy(t4->data, buf1_f32, ggml_nbytes(t4));
+        memcpy(t5->data, buf2_f32, ggml_nbytes(t5));
+
+        struct ggml_tensor * args[] = {
+            t1, t2, t3, t4, t5,
+        };
+
+        struct ggml_tensor * m4 = ggml_custom_4d(ctx, GGML_TYPE_I32, 10, 2, 1, 1, args, sizeof(args)/sizeof(args[0]), custom, GGML_N_TASKS_MAX, g_userdata);
+
+        struct ggml_cgraph * graph = ggml_new_graph(ctx);
+        ggml_build_forward_expand(graph, m4);
+
+        ggml_graph_compute_with_ctx(ctx, graph, 4);
+
+        const int32_t * output = (const int32_t *) ggml_get_data(m4);
+
+        for (int i = 0; i < ggml_nelements(m4); ++i) {
+            GGML_ASSERT(output[i] == buf1_f32[i] + buf2_f32[i] * buf3_f32[i] - buf1_f32[i] * buf2_f32[i]);
+        }
+
+        ggml_free(ctx);
+    }
 
     return 0;
 }