]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add custom mapping functions (#264)
authorLoganDark <redacted>
Sat, 24 Jun 2023 17:47:53 +0000 (10:47 -0700)
committerGitHub <redacted>
Sat, 24 Jun 2023 17:47:53 +0000 (20:47 +0300)
* Add custom mapping functions

The current mapping functions are basically jokes, add some real
ones. These ones get access to the actual tensor structs so they
can do things like

- Know the dimensions they are operating on
- Work with tensors with more than 2 dimensions, or transposed
- Operate on two differently sized tensors (like matmul)
- Use their own thread pool that does a better job than ggml does.

Among other things ...

* fix ordering mistake

* ggml : custom operators support scratch buffers

---------

Co-authored-by: Georgi Gerganov <redacted>
include/ggml/ggml.h
src/ggml.c

index 4b6b7284510f9f62aae8e698a4827d3b83fe037d..5ebd9c46c3d5267df4e2a879cfcb633bd23c6085 100644 (file)
@@ -345,6 +345,10 @@ extern "C" {
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
 
+        GGML_OP_MAP_CUSTOM1,
+        GGML_OP_MAP_CUSTOM2,
+        GGML_OP_MAP_CUSTOM3,
+
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
 
@@ -1167,21 +1171,73 @@ extern "C" {
             int                   h0,
             int                   w);
 
-    // Mapping operations
-    typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
+    // custom operators
+
+    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
     typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
 
+    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context        * ctx,
             struct ggml_tensor         * a,
                    ggml_unary_op_f32_t   fun);
 
+    GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+                   ggml_unary_op_f32_t   fun);
+
     GGML_API struct ggml_tensor * ggml_map_binary_f32(
             struct ggml_context         * ctx,
             struct ggml_tensor          * a,
             struct ggml_tensor          * b,
                    ggml_binary_op_f32_t   fun);
 
+    GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+                   ggml_binary_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom1_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun);
+
     // loss function
 
     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
index aa30e74c24455be8b915a6d8dc6d9429e9be5872..955f335cd18275174f2e3b338db013dc70aa7c2d 100644 (file)
@@ -3728,11 +3728,15 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "MAP_UNARY",
     "MAP_BINARY",
 
+    "MAP_CUSTOM1",
+    "MAP_CUSTOM2",
+    "MAP_CUSTOM3",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
+static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3800,11 +3804,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x)",
     "f(x,y)",
 
+    "custom(x)",
+    "custom(x,y)",
+    "custom(x,y,z)",
+
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61");
+static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -7109,9 +7117,14 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
         is_node = true;
     }
 
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
     *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_load(ctx);
 
     result->op = GGML_OP_MAP_UNARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7151,9 +7164,14 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
         is_node = true;
     }
 
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
     *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_load(ctx);
 
     result->op = GGML_OP_MAP_BINARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7180,6 +7198,150 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
+// ggml_map_custom1
+
+struct ggml_tensor * ggml_map_custom1_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun,
+        bool   inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+
+    ggml_scratch_load(ctx);
+
+    result->op = GGML_OP_MAP_CUSTOM1;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom1_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun) {
+    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom1_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun) {
+    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_custom2
+
+struct ggml_tensor * ggml_map_custom2_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun,
+        bool   inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+
+    ggml_scratch_load(ctx);
+
+    result->op = GGML_OP_MAP_CUSTOM2;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom2_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun) {
+    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom2_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun) {
+    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
+}
+
+// ggml_map_custom3
+
+struct ggml_tensor * ggml_map_custom3_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun,
+        bool   inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad || c->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+
+    ggml_scratch_load(ctx);
+
+    result->op = GGML_OP_MAP_CUSTOM3;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+    result->opt[1] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom3_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun) {
+    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom3_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun) {
+    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
+}
+
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
@@ -14636,6 +14798,114 @@ static void ggml_compute_forward_map_binary(
     }
 }
 
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        struct ggml_tensor * dst,
+        const ggml_custom1_op_f32_t fun) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    fun(dst, a);
+}
+
+
+static void ggml_compute_forward_map_custom1(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        struct ggml_tensor * dst,
+        const ggml_custom1_op_f32_t fun) {
+    switch (a->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_custom1_f32(params, a, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        struct ggml_tensor * dst,
+        const ggml_custom2_op_f32_t fun) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    fun(dst, a, b);
+}
+
+
+static void ggml_compute_forward_map_custom2(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        struct ggml_tensor * dst,
+        const ggml_custom2_op_f32_t fun) {
+    switch (a->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_custom2_f32(params, a, b, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        const struct ggml_tensor * c,
+        struct ggml_tensor * dst,
+        const ggml_custom3_op_f32_t fun) {
+    assert(params->ith == 0);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    fun(dst, a, b, c);
+}
+
+
+static void ggml_compute_forward_map_custom3(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * a,
+        const struct ggml_tensor * b,
+        const struct ggml_tensor * c,
+        struct ggml_tensor * dst,
+        const ggml_custom3_op_f32_t fun) {
+    switch (a->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_custom3_f32(params, a, b, c, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_cross_entropy_loss
 
 static void ggml_compute_forward_cross_entropy_loss_f32(
@@ -15173,6 +15443,24 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
             }
             break;
+        case GGML_OP_MAP_CUSTOM1:
+            {
+                const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM2:
+            {
+                const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun);
+            }
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
@@ -15979,6 +16267,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_WIN_UNPART:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
+        case GGML_OP_MAP_CUSTOM1:
+        case GGML_OP_MAP_CUSTOM2:
+        case GGML_OP_MAP_CUSTOM3:
             {
                 GGML_ASSERT(false); // not supported
             } break;
@@ -16620,6 +16911,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_WIN_UNPART:
                 case GGML_OP_MAP_UNARY:
                 case GGML_OP_MAP_BINARY:
+                case GGML_OP_MAP_CUSTOM1:
+                case GGML_OP_MAP_CUSTOM2:
+                case GGML_OP_MAP_CUSTOM3:
                     {
                         node->n_tasks = 1;
                     } break;