]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: fix gradient allocation logic (ggml/966)
authorJohannes Gäßler <redacted>
Sun, 29 Sep 2024 21:18:02 +0000 (23:18 +0200)
committerGeorgi Gerganov <redacted>
Tue, 1 Oct 2024 13:07:38 +0000 (16:07 +0300)
* ggml: fix gradient allocation logic

* gradient allocation in ggml_build_backward_expand

* fixup

* fix test-backend-ops grad

* suggestions by slaren

* fix test1.c

* fix legacy opt API

* fix test-grad0

* remove keep arg

ggml/include/ggml.h
ggml/src/ggml.c
tests/test-backend-ops.cpp
tests/test-grad0.cpp

index f46d4a8a65f023b1983946ff2dbba6120e523fa9..a8a74bee13de33b066547bace9373e07332b7f87 100644 (file)
@@ -577,10 +577,10 @@ extern "C" {
 
     // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT    = 1, // ...is an input for the GGML compute graph
-        GGML_TENSOR_FLAG_OUTPUT   = 2, // ...is an output for the GGML compute graph
-        GGML_TENSOR_FLAG_PARAM    = 4, // ...contains trainable parameters
-        GGML_TENSOR_FLAG_LOSS     = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_INPUT   1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM   4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS    8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
     // n-dimensional tensor
@@ -1410,14 +1410,14 @@ extern "C" {
     // supports 3D: a->ne[2] == b->ne[1]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * a,  // data
+            struct ggml_tensor  * b); // row indices
 
     GGML_API struct ggml_tensor * ggml_get_rows_back(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c);
+            struct ggml_tensor  * a,  // gradients of ggml_get_rows result
+            struct ggml_tensor  * b,  // row indices
+            struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
 
     GGML_API struct ggml_tensor * ggml_diag(
         struct ggml_context     * ctx,
@@ -1568,9 +1568,9 @@ extern "C" {
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c,
+            struct ggml_tensor  * a, // gradients of ggml_rope result
+            struct ggml_tensor  * b, // positions
+            struct ggml_tensor  * c, // freq factors
             int                   n_dims,
             int                   mode,
             int                   n_ctx_orig,
@@ -2036,15 +2036,15 @@ extern "C" {
     // loss function
 
     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b);
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // logits
+            struct ggml_tensor  * b); // labels
 
     GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
-            struct ggml_context         * ctx,
-            struct ggml_tensor          * a,
-            struct ggml_tensor          * b,
-            struct ggml_tensor          * c);
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // logits
+            struct ggml_tensor  * b,  // labels
+            struct ggml_tensor  * c); // gradients of cross_entropy_loss result
 
     // AdamW optimizer step
     // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
@@ -2066,7 +2066,7 @@ extern "C" {
     GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);
 
     GGML_API void ggml_build_opt_adamw(
             struct ggml_context * ctx,
index 81b651c6a438de5ac708b7a5299b30e52eca6fe0..fa8d6c25a982d08095254939ea5552fda94cf69a 100644 (file)
@@ -4725,18 +4725,11 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam
 
 static struct ggml_tensor * ggml_dup_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
+        struct ggml_tensor  * a,
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_DUP;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_DUP;
     result->src[0] = a;
 
     return result;
@@ -4744,13 +4737,13 @@ static struct ggml_tensor * ggml_dup_impl(
 
 struct ggml_tensor * ggml_dup(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
+        struct ggml_tensor  * a) {
     return ggml_dup_impl(ctx, a, false);
 }
 
 struct ggml_tensor * ggml_dup_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
+        struct ggml_tensor  * a) {
     return ggml_dup_impl(ctx, a, true);
 }
 
@@ -4758,21 +4751,14 @@ struct ggml_tensor * ggml_dup_inplace(
 
 static struct ggml_tensor * ggml_add_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        bool inplace) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
     GGML_ASSERT(ggml_can_repeat(b, a));
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_ADD;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ADD;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -4781,15 +4767,15 @@ static struct ggml_tensor * ggml_add_impl(
 
 struct ggml_tensor * ggml_add(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     return ggml_add_impl(ctx, a, b, false);
 }
 
 struct ggml_tensor * ggml_add_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     return ggml_add_impl(ctx, a, b, true);
 }
 
@@ -4797,9 +4783,9 @@ struct ggml_tensor * ggml_add_inplace(
 
 static struct ggml_tensor * ggml_add_cast_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        enum   ggml_type     type) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum   ggml_type      type) {
     // TODO: support less-strict constraint
     //       GGML_ASSERT(ggml_can_repeat(b, a));
     GGML_ASSERT(ggml_can_repeat_rows(b, a));
@@ -4809,18 +4795,9 @@ static struct ggml_tensor * ggml_add_cast_impl(
                 a->type == GGML_TYPE_F16 ||
                 a->type == GGML_TYPE_BF16);
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        // TODO: support backward pass for broadcasting
-        GGML_ASSERT(ggml_are_same_shape(a, b));
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
 
-    result->op   = GGML_OP_ADD;
-    result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne) : NULL;
+    result->op     = GGML_OP_ADD;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -4829,9 +4806,9 @@ static struct ggml_tensor * ggml_add_cast_impl(
 
 struct ggml_tensor * ggml_add_cast(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        enum   ggml_type     type) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum   ggml_type      type) {
     return ggml_add_cast_impl(ctx, a, b, type);
 }
 
@@ -4839,22 +4816,15 @@ struct ggml_tensor * ggml_add_cast(
 
 static struct ggml_tensor * ggml_add1_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        bool inplace) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
     GGML_ASSERT(ggml_is_scalar(b));
     GGML_ASSERT(ggml_is_padded_1d(a));
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_ADD1;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ADD1;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -4863,15 +4833,15 @@ static struct ggml_tensor * ggml_add1_impl(
 
 struct ggml_tensor * ggml_add1(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     return ggml_add1_impl(ctx, a, b, false);
 }
 
 struct ggml_tensor * ggml_add1_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     return ggml_add1_impl(ctx, a, b, true);
 }
 
@@ -4879,31 +4849,24 @@ struct ggml_tensor * ggml_add1_inplace(
 
 static struct ggml_tensor * ggml_acc_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        size_t               nb1,
-        size_t               nb2,
-        size_t               nb3,
-        size_t               offset,
-        bool inplace) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset,
+        bool                  inplace) {
     GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(a->type == GGML_TYPE_F32);
     GGML_ASSERT(b->type == GGML_TYPE_F32);
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_ACC;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ACC;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -4912,23 +4875,23 @@ static struct ggml_tensor * ggml_acc_impl(
 
 struct ggml_tensor * ggml_acc(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        size_t               nb1,
-        size_t               nb2,
-        size_t               nb3,
-        size_t               offset) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
     return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
 }
 
 struct ggml_tensor * ggml_acc_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        size_t               nb1,
-        size_t               nb2,
-        size_t               nb3,
-        size_t               offset) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
     return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
 }
 
@@ -4936,23 +4899,14 @@ struct ggml_tensor * ggml_acc_inplace(
 
 static struct ggml_tensor * ggml_sub_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        bool inplace) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
     GGML_ASSERT(ggml_can_repeat(b, a));
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        // TODO: support backward pass for broadcasting
-        GGML_ASSERT(ggml_are_same_shape(a, b));
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_SUB;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SUB;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -4961,15 +4915,15 @@ static struct ggml_tensor * ggml_sub_impl(
 
 struct ggml_tensor * ggml_sub(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     return ggml_sub_impl(ctx, a, b, false);
 }
 
 struct ggml_tensor * ggml_sub_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     return ggml_sub_impl(ctx, a, b, true);
 }
 
@@ -4977,27 +4931,14 @@ struct ggml_tensor * ggml_sub_inplace(
 
 static struct ggml_tensor * ggml_mul_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        bool inplace) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
     GGML_ASSERT(ggml_can_repeat(b, a));
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        // TODO: support backward pass for broadcasting
-        GGML_ASSERT(ggml_are_same_shape(a, b));
-        is_node = true;
-    }
-
-    if (inplace) {
-        GGML_ASSERT(!is_node);
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_MUL;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MUL;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5022,25 +4963,14 @@ struct ggml_tensor * ggml_mul_inplace(
 
 static struct ggml_tensor * ggml_div_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b,
-        bool inplace) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
     GGML_ASSERT(ggml_can_repeat(b, a));
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        is_node = true;
-    }
-
-    if (inplace) {
-        GGML_ASSERT(!is_node);
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_DIV;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_DIV;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5065,18 +4995,11 @@ struct ggml_tensor * ggml_div_inplace(
 
 static struct ggml_tensor * ggml_sqr_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
+        struct ggml_tensor  * a,
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_SQR;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SQR;
     result->src[0] = a;
 
     return result;
@@ -5098,18 +5021,11 @@ struct ggml_tensor * ggml_sqr_inplace(
 
 static struct ggml_tensor * ggml_sqrt_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
+        struct ggml_tensor  * a,
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_SQRT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SQRT;
     result->src[0] = a;
 
     return result;
@@ -5132,17 +5048,10 @@ struct ggml_tensor * ggml_sqrt_inplace(
 static struct ggml_tensor * ggml_log_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_LOG;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_LOG;
     result->src[0] = a;
 
     return result;
@@ -5165,17 +5074,10 @@ struct ggml_tensor * ggml_log_inplace(
 static struct ggml_tensor * ggml_sin_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_SIN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SIN;
     result->src[0] = a;
 
     return result;
@@ -5198,17 +5100,10 @@ struct ggml_tensor * ggml_sin_inplace(
 static struct ggml_tensor * ggml_cos_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_COS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_COS;
     result->src[0] = a;
 
     return result;
@@ -5230,17 +5125,10 @@ struct ggml_tensor * ggml_cos_inplace(
 
 struct ggml_tensor * ggml_sum(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
+        struct ggml_tensor  * a) {
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
 
-    result->op   = GGML_OP_SUM;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SUM;
     result->src[0] = a;
 
     return result;
@@ -5250,13 +5138,7 @@ struct ggml_tensor * ggml_sum(
 
 struct ggml_tensor * ggml_sum_rows(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
+        struct ggml_tensor  * a) {
     int64_t ne[GGML_MAX_DIMS] = { 1 };
     for (int i = 1; i < GGML_MAX_DIMS; ++i) {
         ne[i] = a->ne[i];
@@ -5264,8 +5146,7 @@ struct ggml_tensor * ggml_sum_rows(
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
 
-    result->op   = GGML_OP_SUM_ROWS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SUM_ROWS;
     result->src[0] = a;
 
     return result;
@@ -5275,19 +5156,11 @@ struct ggml_tensor * ggml_sum_rows(
 
 struct ggml_tensor * ggml_mean(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement
-        is_node = true;
-    }
-
+        struct ggml_tensor  * a) {
     int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    result->op   = GGML_OP_MEAN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MEAN;
     result->src[0] = a;
 
     return result;
@@ -5297,19 +5170,12 @@ struct ggml_tensor * ggml_mean(
 
 struct ggml_tensor * ggml_argmax(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
+        struct ggml_tensor  * a) {
     GGML_ASSERT(ggml_is_matrix(a));
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error");
-        is_node = true;
-    }
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
 
-    result->op   = GGML_OP_ARGMAX;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ARGMAX;
     result->src[0] = a;
 
     return result;
@@ -5319,20 +5185,13 @@ struct ggml_tensor * ggml_argmax(
 
 struct ggml_tensor * ggml_repeat(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     GGML_ASSERT(ggml_can_repeat(a, b));
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
 
-    result->op   = GGML_OP_REPEAT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_REPEAT;
     result->src[0] = a;
 
     return result;
@@ -5342,24 +5201,13 @@ struct ggml_tensor * ggml_repeat(
 
 struct ggml_tensor * ggml_repeat_back(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        struct ggml_tensor * b) {
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     GGML_ASSERT(ggml_can_repeat(b, a));
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
-    if (ggml_are_same_shape(a, b) && !is_node) {
-        return a;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
 
-    result->op   = GGML_OP_REPEAT_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_REPEAT_BACK;
     result->src[0] = a;
 
     return result;
@@ -5369,9 +5217,9 @@ struct ggml_tensor * ggml_repeat_back(
 
 struct ggml_tensor * ggml_concat(
     struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    struct ggml_tensor * b,
-    int dim) {
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b,
+    int                   dim) {
     GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
 
     int64_t ne[GGML_MAX_DIMS];
@@ -5384,19 +5232,11 @@ struct ggml_tensor * ggml_concat(
         ne[d] = a->ne[d];
     }
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
 
     ggml_set_op_params_i32(result, 0, dim);
 
-    result->op = GGML_OP_CONCAT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CONCAT;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5505,20 +5345,14 @@ struct ggml_tensor * ggml_relu_inplace(
 
 struct ggml_tensor * ggml_leaky_relu(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a, float negative_slope, bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        GGML_ABORT("fatal error"); // TODO: not implemented
-        is_node = true;
-    }
-
+        struct ggml_tensor  * a,
+        float                 negative_slope,
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
-    result->op   = GGML_OP_LEAKY_RELU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_LEAKY_RELU;
     result->src[0] = a;
 
     return result;
@@ -5586,17 +5420,9 @@ struct ggml_tensor * ggml_silu_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        // TODO: implement backward
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_SILU_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SILU_BACK;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5604,6 +5430,7 @@ struct ggml_tensor * ggml_silu_back(
 }
 
 // ggml hardswish
+
 struct ggml_tensor * ggml_hardswish(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
@@ -5611,6 +5438,7 @@ struct ggml_tensor * ggml_hardswish(
 }
 
 // ggml hardsigmoid
+
 struct ggml_tensor * ggml_hardsigmoid(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
@@ -5618,6 +5446,7 @@ struct ggml_tensor * ggml_hardsigmoid(
 }
 
 // ggml exp
+
 struct ggml_tensor * ggml_exp(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
@@ -5635,21 +5464,13 @@ struct ggml_tensor * ggml_exp_inplace(
 static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        float eps,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
+        float                 eps,
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, &eps, sizeof(eps));
 
-    result->op   = GGML_OP_NORM;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_NORM;
     result->src[0] = a;
 
     return result;
@@ -5658,14 +5479,14 @@ static struct ggml_tensor * ggml_norm_impl(
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        float eps) {
+        float                 eps) {
     return ggml_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_norm_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        float eps) {
+        float                 eps) {
     return ggml_norm_impl(ctx, a, eps, true);
 }
 
@@ -5674,20 +5495,13 @@ struct ggml_tensor * ggml_norm_inplace(
 static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        float eps,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
+        float                 eps,
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, &eps, sizeof(eps));
 
-    result->op   = GGML_OP_RMS_NORM;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RMS_NORM;
     result->src[0] = a;
 
     return result;
@@ -5696,14 +5510,14 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 struct ggml_tensor * ggml_rms_norm(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        float  eps) {
+        float                 eps) {
     return ggml_rms_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_rms_norm_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        float eps) {
+        float                 eps) {
     return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
@@ -5713,20 +5527,12 @@ struct ggml_tensor * ggml_rms_norm_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
-        float  eps) {
-    bool is_node = false;
-
-    if (a->grad) {
-        // TODO: implement backward
-        is_node = true;
-    }
-
+        float                 eps) {
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, &eps, sizeof(eps));
 
-    result->op   = GGML_OP_RMS_NORM_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RMS_NORM_BACK;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5736,43 +5542,35 @@ struct ggml_tensor * ggml_rms_norm_back(
 // ggml_group_norm
 
 static struct ggml_tensor * ggml_group_norm_impl(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    int n_groups,
-    float eps,
-    bool inplace) {
-
-    bool is_node = false;
-    if (!inplace && (a->grad)) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_groups,
+        float                 eps,
+        bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params_i32(result, 0, n_groups);
     ggml_set_op_params_f32(result, 1, eps);
 
-    result->op = GGML_OP_GROUP_NORM;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_GROUP_NORM;
     result->src[0] = a;
 
     return result;
 }
 
 struct ggml_tensor * ggml_group_norm(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    int n_groups,
-    float eps) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_groups,
+        float                 eps) {
     return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
 }
 
 struct ggml_tensor * ggml_group_norm_inplace(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    int n_groups,
-    float eps) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_groups,
+        float                 eps) {
     return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
@@ -5785,17 +5583,10 @@ struct ggml_tensor * ggml_mul_mat(
     GGML_ASSERT(ggml_can_mul_mat(a, b));
     GGML_ASSERT(!ggml_is_transposed(a));
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true;
-    }
-
     const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    result->op   = GGML_OP_MUL_MAT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MUL_MAT;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5841,17 +5632,10 @@ struct ggml_tensor * ggml_mul_mat_id(
     GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
     GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
 
-    bool is_node = false;
-
-    if (as->grad || b->grad) {
-        is_node = true;
-    }
-
     const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    result->op   = GGML_OP_MUL_MAT_ID;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MUL_MAT_ID;
     result->src[0] = as;
     result->src[1] = b;
     result->src[2] = ids;
@@ -5868,18 +5652,11 @@ struct ggml_tensor * ggml_out_prod(
     GGML_ASSERT(ggml_can_out_prod(a, b));
     GGML_ASSERT(!ggml_is_transposed(a));
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true;
-    }
-
     // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
     const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    result->op   = GGML_OP_OUT_PROD;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_OUT_PROD;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5892,21 +5669,14 @@ static struct ggml_tensor * ggml_scale_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         float                 s,
-        bool inplace) {
+        bool                  inplace) {
     GGML_ASSERT(ggml_is_padded_1d(a));
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, &s, sizeof(s));
 
-    result->op   = GGML_OP_SCALE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SCALE;
     result->src[0] = a;
 
     return result;
@@ -5914,15 +5684,15 @@ static struct ggml_tensor * ggml_scale_impl(
 
 struct ggml_tensor * ggml_scale(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        float                s) {
+        struct ggml_tensor  * a,
+        float                 s) {
     return ggml_scale_impl(ctx, a, s, false);
 }
 
 struct ggml_tensor * ggml_scale_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        float                s) {
+        struct ggml_tensor  * a,
+        float                 s) {
     return ggml_scale_impl(ctx, a, s, true);
 }
 
@@ -5936,15 +5706,9 @@ static struct ggml_tensor * ggml_set_impl(
         size_t                nb2,
         size_t                nb3,
         size_t                offset,
-        bool inplace) {
+        bool                  inplace) {
     GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true;
-    }
-
     // make a view of the destination
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
@@ -5952,8 +5716,7 @@ static struct ggml_tensor * ggml_set_impl(
     int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_SET;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SET;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -5962,8 +5725,8 @@ static struct ggml_tensor * ggml_set_impl(
 
 struct ggml_tensor * ggml_set(
         struct ggml_context * ctx,
-        struct ggml_tensor  a,
-        struct ggml_tensor  b,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
         size_t                nb1,
         size_t                nb2,
         size_t                nb3,
@@ -5973,8 +5736,8 @@ struct ggml_tensor * ggml_set(
 
 struct ggml_tensor * ggml_set_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  a,
-        struct ggml_tensor  b,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
         size_t                nb1,
         size_t                nb2,
         size_t                nb3,
@@ -5984,24 +5747,24 @@ struct ggml_tensor * ggml_set_inplace(
 
 struct ggml_tensor * ggml_set_1d(
         struct ggml_context * ctx,
-        struct ggml_tensor  a,
-        struct ggml_tensor  b,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
         size_t                offset) {
     return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
 }
 
 struct ggml_tensor * ggml_set_1d_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  a,
-        struct ggml_tensor  b,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
         size_t                offset) {
     return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
 }
 
 struct ggml_tensor * ggml_set_2d(
         struct ggml_context * ctx,
-        struct ggml_tensor  a,
-        struct ggml_tensor  b,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
         size_t                nb1,
         size_t                offset) {
     return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
@@ -6009,8 +5772,8 @@ struct ggml_tensor * ggml_set_2d(
 
 struct ggml_tensor * ggml_set_2d_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  a,
-        struct ggml_tensor  b,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
         size_t                nb1,
         size_t                offset) {
     return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
@@ -6024,13 +5787,6 @@ static struct ggml_tensor * ggml_cpy_impl(
         struct ggml_tensor  * b) {
     GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        // inplace is false and either one have a grad
-        is_node = true;
-    }
-
     // make a view of the destination
     struct ggml_tensor * result = ggml_view_tensor(ctx, b);
     if (strlen(b->name) > 0) {
@@ -6039,8 +5795,7 @@ static struct ggml_tensor * ggml_cpy_impl(
         ggml_format_name(result, "%s (copy)", a->name);
     }
 
-    result->op   = GGML_OP_CPY;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CPY;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -6058,15 +5813,11 @@ struct ggml_tensor * ggml_cast(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         enum   ggml_type      type) {
-    bool is_node = false;
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
     ggml_format_name(result, "%s (copy)", a->name);
 
-    result->op   = GGML_OP_CPY;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CPY;
     result->src[0] = a;
-    result->src[1] = result;
 
     return result;
 }
@@ -6076,17 +5827,10 @@ struct ggml_tensor * ggml_cast(
 static struct ggml_tensor * ggml_cont_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
     ggml_format_name(result, "%s (cont)", a->name);
 
-    result->op   = GGML_OP_CONT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CONT;
     result->src[0] = a;
 
     return result;
@@ -6132,13 +5876,10 @@ struct ggml_tensor * ggml_cont_4d(
         int64_t               ne3) {
     GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
 
-    bool is_node = false;
-
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
     ggml_format_name(result, "%s (cont)", a->name);
 
-    result->op   = GGML_OP_CONT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CONT;
     result->src[0] = a;
 
     return result;
@@ -6154,22 +5895,10 @@ struct ggml_tensor * ggml_reshape(
     // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
     GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
-    if (b->grad) {
-        // gradient propagation is not supported
-        //GGML_ABORT("fatal error");
-    }
-
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
-    result->op   = GGML_OP_RESHAPE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RESHAPE;
     result->src[0] = a;
 
     return result;
@@ -6182,18 +5911,11 @@ struct ggml_tensor * ggml_reshape_1d(
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_nelements(a) == ne0);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     const int64_t ne[1] = { ne0 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
-    result->op   = GGML_OP_RESHAPE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RESHAPE;
     result->src[0] = a;
 
     return result;
@@ -6207,18 +5929,11 @@ struct ggml_tensor * ggml_reshape_2d(
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     const int64_t ne[2] = { ne0, ne1 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
-    result->op   = GGML_OP_RESHAPE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RESHAPE;
     result->src[0] = a;
 
     return result;
@@ -6233,18 +5948,11 @@ struct ggml_tensor * ggml_reshape_3d(
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     const int64_t ne[3] = { ne0, ne1, ne2 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
-    result->op   = GGML_OP_RESHAPE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RESHAPE;
     result->src[0] = a;
 
     return result;
@@ -6260,18 +5968,11 @@ struct ggml_tensor * ggml_reshape_4d(
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
-    result->op   = GGML_OP_RESHAPE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RESHAPE;
     result->src[0] = a;
 
     return result;
@@ -6283,20 +5984,12 @@ static struct ggml_tensor * ggml_view_impl(
         int                   n_dims,
         const int64_t       * ne,
         size_t                offset) {
-
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
     ggml_format_name(result, "%s (view)", a->name);
 
     ggml_set_op_params(result, &offset, sizeof(offset));
 
-    result->op   = GGML_OP_VIEW;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_VIEW;
     result->src[0] = a;
 
     return result;
@@ -6309,7 +6002,6 @@ struct ggml_tensor * ggml_view_1d(
         struct ggml_tensor  * a,
         int64_t               ne0,
         size_t                offset) {
-
     struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);
 
     return result;
@@ -6324,7 +6016,6 @@ struct ggml_tensor * ggml_view_2d(
         int64_t               ne1,
         size_t                nb1,
         size_t                offset) {
-
     const int64_t ne[2] = { ne0, ne1 };
 
     struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);
@@ -6347,7 +6038,6 @@ struct ggml_tensor * ggml_view_3d(
         size_t                nb1,
         size_t                nb2,
         size_t                offset) {
-
     const int64_t ne[3] = { ne0, ne1, ne2 };
 
     struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);
@@ -6372,7 +6062,6 @@ struct ggml_tensor * ggml_view_4d(
         size_t                nb2,
         size_t                nb3,
         size_t                offset) {
-
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
 
     struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);
@@ -6405,12 +6094,6 @@ struct ggml_tensor * ggml_permute(
     GGML_ASSERT(axis1 != axis3);
     GGML_ASSERT(axis2 != axis3);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
     ggml_format_name(result, "%s (permuted)", a->name);
 
@@ -6437,8 +6120,7 @@ struct ggml_tensor * ggml_permute(
     result->nb[2] = nb[2];
     result->nb[3] = nb[3];
 
-    result->op   = GGML_OP_PERMUTE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_PERMUTE;
     result->src[0] = a;
 
     int32_t params[] = { axis0, axis1, axis2, axis3 };
@@ -6452,12 +6134,6 @@ struct ggml_tensor * ggml_permute(
 struct ggml_tensor * ggml_transpose(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
     ggml_format_name(result, "%s (transposed)", a->name);
 
@@ -6467,8 +6143,7 @@ struct ggml_tensor * ggml_transpose(
     result->nb[0] = a->nb[1];
     result->nb[1] = a->nb[0];
 
-    result->op   = GGML_OP_TRANSPOSE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_TRANSPOSE;
     result->src[0] = a;
 
     return result;
@@ -6484,12 +6159,6 @@ struct ggml_tensor * ggml_get_rows(
     GGML_ASSERT(b->ne[3] == 1);
     GGML_ASSERT(b->type == GGML_TYPE_I32);
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true;
-    }
-
     // TODO: implement non F32 return
     enum ggml_type type = GGML_TYPE_F32;
     if (a->type == GGML_TYPE_I32) {
@@ -6497,8 +6166,7 @@ struct ggml_tensor * ggml_get_rows(
     }
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
 
-    result->op   = GGML_OP_GET_ROWS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_GET_ROWS;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -6515,18 +6183,11 @@ struct ggml_tensor * ggml_get_rows_back(
     GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
     GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true;
-    }
-
     // TODO: implement non F32 return
     //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
     struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);
 
-    result->op   = GGML_OP_GET_ROWS_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_GET_ROWS_BACK;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -6539,17 +6200,11 @@ struct ggml_tensor * ggml_diag(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
     GGML_ASSERT(a->ne[1] == 1);
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
 
     const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne);
 
-    result->op   = GGML_OP_DIAG;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_DIAG;
     result->src[0] = a;
 
     return result;
@@ -6562,19 +6217,12 @@ static struct ggml_tensor * ggml_diag_mask_inf_impl(
         struct ggml_tensor  * a,
         int                   n_past,
         bool                  inplace) {
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[] = { n_past };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_DIAG_MASK_INF;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_DIAG_MASK_INF;
     result->src[0] = a;
 
     return result;
@@ -6601,19 +6249,12 @@ static struct ggml_tensor * ggml_diag_mask_zero_impl(
         struct ggml_tensor  * a,
         int                   n_past,
         bool                  inplace) {
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[] = { n_past };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_DIAG_MASK_ZERO;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_DIAG_MASK_ZERO;
     result->src[0] = a;
 
     return result;
@@ -6656,19 +6297,12 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(mask);
     }
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_SOFT_MAX;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SOFT_MAX;
     result->src[0] = a;
     result->src[1] = mask;
 
@@ -6703,16 +6337,9 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         bool                  inplace) {
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true; // TODO : implement backward pass
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_SOFT_MAX_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SOFT_MAX_BACK;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -6761,12 +6388,6 @@ static struct ggml_tensor * ggml_rope_impl(
         GGML_ASSERT(c->ne[0] >= n_dims / 2);
     }
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -6778,8 +6399,7 @@ static struct ggml_tensor * ggml_rope_impl(
     memcpy(params + 10, &beta_slow,    sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_ROPE;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ROPE;
     result->src[0] = a;
     result->src[1] = b;
     result->src[2] = c;
@@ -6907,13 +6527,6 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ASSERT(false && "backwards pass not implemented");
-        is_node = false;
-    }
-
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
     int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -6925,8 +6538,7 @@ struct ggml_tensor * ggml_rope_back(
     memcpy(params + 10, &beta_slow,    sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_ROPE_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ROPE_BACK;
     result->src[0] = a;
     result->src[1] = b;
     result->src[2] = c;
@@ -6941,21 +6553,13 @@ struct ggml_tensor * ggml_clamp(
         struct ggml_tensor  * a,
         float                 min,
         float                 max) {
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
     // TODO: when implement backward, fix this:
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
     float params[] = { min, max };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_CLAMP;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CLAMP;
     result->src[0] = a;
 
     return result;
@@ -7017,13 +6621,6 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
     GGML_ASSERT(p0 == 0);
     GGML_ASSERT(d0 == 1);
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
     const int64_t ne[4] = {
         ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
         a->ne[1], b->ne[2], 1,
@@ -7033,8 +6630,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
     int32_t params[] = { s0, p0, d0 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_CONV_TRANSPOSE_1D;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -7042,17 +6638,17 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
 }
 
 // ggml_conv_depthwise
-struct ggml_tensor * ggml_conv_depthwise_2d(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    struct ggml_tensor * b,
-    int                  s0,
-    int                  s1,
-    int                  p0,
-    int                  p1,
-    int                  d0,
-    int                  d1) {
 
+struct ggml_tensor * ggml_conv_depthwise_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
     struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
     struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
                                         ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
@@ -7072,29 +6668,23 @@ struct ggml_tensor * ggml_conv_depthwise_2d(
 // b: [N, IC, IH, IW]
 // result: [N, OH, OW, IC*KH*KW]
 struct ggml_tensor * ggml_im2col(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b,
-    int                  s0,
-    int                  s1,
-    int                  p0,
-    int                  p1,
-    int                  d0,
-    int                  d1,
-    bool                 is_2D,
-    enum ggml_type       dst_type) {
-
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1,
+        bool                  is_2D,
+        enum ggml_type        dst_type) {
     if(is_2D) {
         GGML_ASSERT(a->ne[2] == b->ne[2]);
     } else {
         GGML_ASSERT(a->ne[1] == b->ne[1]);
         GGML_ASSERT(b->ne[3] == 1);
     }
-    bool is_node = false;
-
-    if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
-        is_node = true;
-    }
 
     const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
     const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
@@ -7113,8 +6703,7 @@ struct ggml_tensor * ggml_im2col(
     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_IM2COL;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_IM2COL;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -7122,30 +6711,22 @@ struct ggml_tensor * ggml_im2col(
 }
 
 struct ggml_tensor * ggml_im2col_back(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b,
-    int64_t             * ne,
-    int                   s0,
-    int                   s1,
-    int                   p0,
-    int                   p1,
-    int                   d0,
-    int                   d1,
-    bool                  is_2D) {
-
-    bool is_node = false;
-
-    if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
-        is_node = true;
-    }
-
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int64_t             * ne,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1,
+        bool                  is_2D) {
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
     int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_IM2COL_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_IM2COL_BACK;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -7159,12 +6740,12 @@ struct ggml_tensor * ggml_conv_2d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
-        int                  s0,
-        int                  s1,
-        int                  p0,
-        int                  p1,
-        int                  d0,
-        int                  d1) {
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
     struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
 
     struct ggml_tensor * result =
@@ -7180,6 +6761,7 @@ struct ggml_tensor * ggml_conv_2d(
 }
 
 // ggml_conv_2d_sk_p0
+
 struct ggml_tensor * ggml_conv_2d_sk_p0(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -7209,13 +6791,6 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0(
         int                   stride) {
     GGML_ASSERT(a->ne[3] == b->ne[2]);
 
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
     const int64_t ne[4] = {
         ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
         ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
@@ -7226,8 +6801,7 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0(
 
     ggml_set_op_params_i32(result, 0, stride);
 
-    result->op = GGML_OP_CONV_TRANSPOSE_2D;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CONV_TRANSPOSE_2D;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -7247,16 +6821,8 @@ struct ggml_tensor * ggml_pool_1d(
         struct ggml_tensor  * a,
         enum ggml_op_pool     op,
         int                   k0,
-        int                   s0,
-        int                   p0) {
-
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
+        int                   s0,
+        int                   p0) {
     const int64_t ne[4] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         a->ne[1],
@@ -7268,8 +6834,7 @@ struct ggml_tensor * ggml_pool_1d(
     int32_t params[] = { op, k0, s0, p0 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_POOL_1D;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_POOL_1D;
     result->src[0] = a;
 
     return result;
@@ -7287,13 +6852,6 @@ struct ggml_tensor * ggml_pool_2d(
         int                   s1,
         float                 p0,
         float                 p1) {
-
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result;
     const int64_t ne[4] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
@@ -7306,9 +6864,9 @@ struct ggml_tensor * ggml_pool_2d(
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_POOL_2D;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_POOL_2D;
     result->src[0] = a;
+
     return result;
 }
 
@@ -7323,100 +6881,74 @@ struct ggml_tensor * ggml_pool_2d_back(
         int                   s1,
         float                 p0,
         float                 p1) {
-
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result;
     result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op = GGML_OP_POOL_2D_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_POOL_2D_BACK;
     result->src[0] = a;
     result->src[1] = af;
+
     return result;
 }
 
 // ggml_upscale
 
 static struct ggml_tensor * ggml_upscale_impl(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    int ne0,
-    int ne1,
-    int ne2,
-    int ne3) {
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2,
+        int                   ne3) {
     GGML_ASSERT(a->ne[0] <= ne0);
     GGML_ASSERT(a->ne[1] <= ne1);
     GGML_ASSERT(a->ne[2] <= ne2);
     GGML_ASSERT(a->ne[3] <= ne3);
 
-    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
-            ne0,
-            ne1,
-            ne2,
-            ne3
-            );
-
-    result->op = GGML_OP_UPSCALE;
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_UPSCALE;
     result->src[0] = a;
 
     return result;
 }
 
 struct ggml_tensor * ggml_upscale(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    int scale_factor) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   scale_factor) {
     return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
-    struct ggml_context * ctx,
-    struct ggml_tensor * a,
-    int ne0,
-    int ne1,
-    int ne2,
-    int ne3) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2,
+        int                   ne3) {
     return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
 }
 
 // ggml_pad
 
 struct ggml_tensor * ggml_pad(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    int p0, int p1, int p2, int p3) {
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   p3) {
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
             a->ne[0] + p0,
             a->ne[1] + p1,
             a->ne[2] + p2,
             a->ne[3] + p3);
 
-    result->op = GGML_OP_PAD;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_PAD;
     result->src[0] = a;
 
     return result;
@@ -7425,39 +6957,32 @@ struct ggml_tensor * ggml_pad(
 // ggml_arange
 
 struct ggml_tensor * ggml_arange(
-    struct ggml_context * ctx,
-    float start,
-    float stop,
-    float step) {
-
+        struct ggml_context * ctx,
+        float                 start,
+        float                 stop,
+        float                 step) {
     GGML_ASSERT(stop > start);
 
     const int64_t steps = (int64_t) ceilf((stop - start) / step);
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
 
-    result->op = GGML_OP_ARANGE;
     ggml_set_op_params_f32(result, 0, start);
     ggml_set_op_params_f32(result, 1, stop);
     ggml_set_op_params_f32(result, 2, step);
 
+    result->op = GGML_OP_ARANGE;
+
     return result;
 }
 
 // ggml_timestep_embedding
 
 struct ggml_tensor * ggml_timestep_embedding(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * timesteps,
-            int                   dim,
-            int                   max_period) {
-    bool is_node = false;
-
-    if (timesteps->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
+        struct ggml_context * ctx,
+        struct ggml_tensor  * timesteps,
+        int                   dim,
+        int                   max_period) {
     int actual_dim = dim;
     if (dim % 2 != 0) {
         actual_dim = dim + 1;
@@ -7465,11 +6990,10 @@ struct ggml_tensor * ggml_timestep_embedding(
 
     struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
 
-    result->op = GGML_OP_TIMESTEP_EMBEDDING;
     ggml_set_op_params_i32(result, 0, dim);
     ggml_set_op_params_i32(result, 1, max_period);
 
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_TIMESTEP_EMBEDDING;
     result->src[0] = timesteps;
 
     return result;
@@ -7478,22 +7002,14 @@ struct ggml_tensor * ggml_timestep_embedding(
 // ggml_argsort
 
 struct ggml_tensor * ggml_argsort(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        enum ggml_sort_order  order) {
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: not implemented
-        is_node = true;
-    }
-
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        enum ggml_sort_order   order) {
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
 
     ggml_set_op_params_i32(result, 0, (int32_t) order);
 
-    result->op   = GGML_OP_ARGSORT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ARGSORT;
     result->src[0] = a;
 
     return result;
@@ -7546,10 +7062,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
     bool is_node = false;
 
-    if (q->grad || k->grad || v->grad) {
-        is_node = true;
-    }
-
     // permute(0, 2, 1, 3)
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
@@ -7676,17 +7188,9 @@ struct ggml_tensor * ggml_ssm_conv(
     GGML_ASSERT(sx->ne[1] == d_inner);
     GGML_ASSERT(n_t >= 0);
 
-    bool is_node = false;
-
-    if (sx->grad || c->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement
-        is_node = true;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
 
-    result->op   = GGML_OP_SSM_CONV;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_SSM_CONV;
     result->src[0] = sx;
     result->src[1] = c;
 
@@ -7730,18 +7234,10 @@ struct ggml_tensor * ggml_ssm_scan(
         GGML_ASSERT(B->ne[2] == n_seqs);
     }
 
-    bool is_node = false;
-
-    if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement
-        is_node = true;
-    }
-
     // concatenated y + ssm_states
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
 
     result->op   = GGML_OP_SSM_SCAN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = s;
     result->src[1] = x;
     result->src[2] = dt;
@@ -7761,13 +7257,6 @@ struct ggml_tensor * ggml_win_part(
     GGML_ASSERT(a->ne[3] == 1);
     GGML_ASSERT(a->type  == GGML_TYPE_F32);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
     // padding
     const int px = (w - a->ne[1]%w)%w;
     const int py = (w - a->ne[2]%w)%w;
@@ -7782,8 +7271,7 @@ struct ggml_tensor * ggml_win_part(
     int32_t params[] = { npx, npy, w };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_WIN_PART;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_WIN_PART;
     result->src[0] = a;
 
     return result;
@@ -7799,21 +7287,13 @@ struct ggml_tensor * ggml_win_unpart(
         int                   w) {
     GGML_ASSERT(a->type == GGML_TYPE_F32);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
     const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
     int32_t params[] = { w };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_WIN_UNPART;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_WIN_UNPART;
     result->src[0] = a;
 
     return result;
@@ -7829,18 +7309,10 @@ struct ggml_tensor * ggml_get_rel_pos(
     GGML_ASSERT(qh == kh);
     GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
 
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
     const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
 
-    result->op   = GGML_OP_GET_REL_POS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_GET_REL_POS;
     result->src[0] = a;
 
     return result;
@@ -7864,17 +7336,10 @@ static struct ggml_tensor * ggml_add_rel_pos_impl(
     GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
     GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || pw->grad || ph->grad)) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
 
-    result->op   = GGML_OP_ADD_REL_POS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_ADD_REL_POS;
     result->src[0] = a;
     result->src[1] = pw;
     result->src[2] = ph;
@@ -7902,12 +7367,12 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
 
 struct ggml_tensor * ggml_rwkv_wkv(
         struct ggml_context * ctx,
-        struct ggml_tensor * k,
-        struct ggml_tensor * v,
-        struct ggml_tensor * r,
-        struct ggml_tensor * tf,
-        struct ggml_tensor * td,
-        struct ggml_tensor * state) {
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * tf,
+        struct ggml_tensor  * td,
+        struct ggml_tensor  * state) {
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(r));
@@ -7928,19 +7393,11 @@ struct ggml_tensor * ggml_rwkv_wkv(
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
-    bool is_node = false;
-
-    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
     // concat output and new_state
     const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    result->op   = GGML_OP_RWKV_WKV;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_RWKV_WKV;
     result->src[0] = k;
     result->src[1] = v;
     result->src[2] = r;
@@ -7955,23 +7412,16 @@ struct ggml_tensor * ggml_rwkv_wkv(
 
 static struct ggml_tensor * ggml_unary_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        enum ggml_unary_op op,
-        bool inplace) {
+        struct ggml_tensor  * a,
+        enum ggml_unary_op    op,
+        bool                  inplace) {
     GGML_ASSERT(ggml_is_contiguous_1(a));
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params_i32(result, 0, (int32_t) op);
 
-    result->op   = GGML_OP_UNARY;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_UNARY;
     result->src[0] = a;
 
     return result;
@@ -7980,14 +7430,14 @@ static struct ggml_tensor * ggml_unary_impl(
 struct ggml_tensor * ggml_unary(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        enum ggml_unary_op op) {
+        enum ggml_unary_op    op) {
     return ggml_unary_impl(ctx, a, op, false);
 }
 
 struct ggml_tensor * ggml_unary_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        enum ggml_unary_op op) {
+        enum ggml_unary_op    op) {
     return ggml_unary_impl(ctx, a, op, true);
 }
 
@@ -7996,20 +7446,13 @@ struct ggml_tensor * ggml_unary_inplace(
 static struct ggml_tensor * ggml_map_unary_impl_f32(
         struct ggml_context        * ctx,
         struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t fun,
-        bool   inplace) {
-    bool is_node = false;
-
-    if (!inplace && a->grad) {
-        is_node = true;
-    }
-
+        const  ggml_unary_op_f32_t   fun,
+        bool                         inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
-    result->op = GGML_OP_MAP_UNARY;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_UNARY;
     result->src[0] = a;
 
     return result;
@@ -8018,14 +7461,14 @@ static struct ggml_tensor * ggml_map_unary_impl_f32(
 struct ggml_tensor * ggml_map_unary_f32(
         struct ggml_context        * ctx,
         struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t fun) {
+        const  ggml_unary_op_f32_t   fun) {
     return ggml_map_unary_impl_f32(ctx, a, fun, false);
 }
 
 struct ggml_tensor * ggml_map_unary_inplace_f32(
         struct ggml_context        * ctx,
         struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t fun) {
+        const  ggml_unary_op_f32_t   fun) {
     return ggml_map_unary_impl_f32(ctx, a, fun, true);
 }
 
@@ -8035,22 +7478,15 @@ static struct ggml_tensor * ggml_map_binary_impl_f32(
         struct ggml_context         * ctx,
         struct ggml_tensor          * a,
         struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t fun,
-        bool   inplace) {
+        const  ggml_binary_op_f32_t   fun,
+        bool                          inplace) {
     GGML_ASSERT(ggml_are_same_shape(a, b));
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
-    result->op = GGML_OP_MAP_BINARY;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_BINARY;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -8061,7 +7497,7 @@ struct ggml_tensor * ggml_map_binary_f32(
         struct ggml_context         * ctx,
         struct ggml_tensor          * a,
         struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t fun) {
+        const  ggml_binary_op_f32_t   fun) {
     return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
 }
 
@@ -8069,7 +7505,7 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
         struct ggml_context         * ctx,
         struct ggml_tensor          * a,
         struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t fun) {
+        const  ggml_binary_op_f32_t   fun) {
     return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
 }
 
@@ -8079,19 +7515,12 @@ static struct ggml_tensor * ggml_map_custom1_impl_f32(
         struct ggml_context          * ctx,
         struct ggml_tensor           * a,
         const  ggml_custom1_op_f32_t   fun,
-        bool   inplace) {
-    bool is_node = false;
-
-    if (!inplace && a->grad) {
-        is_node = true;
-    }
-
+        bool                           inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
-    result->op = GGML_OP_MAP_CUSTOM1_F32;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_CUSTOM1_F32;
     result->src[0] = a;
 
     return result;
@@ -8118,19 +7547,12 @@ static struct ggml_tensor * ggml_map_custom2_impl_f32(
         struct ggml_tensor           * a,
         struct ggml_tensor           * b,
         const  ggml_custom2_op_f32_t   fun,
-        bool   inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        is_node = true;
-    }
-
+        bool                           inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
-    result->op = GGML_OP_MAP_CUSTOM2_F32;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_CUSTOM2_F32;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -8161,19 +7583,12 @@ static struct ggml_tensor * ggml_map_custom3_impl_f32(
         struct ggml_tensor           * b,
         struct ggml_tensor           * c,
         const  ggml_custom3_op_f32_t   fun,
-        bool   inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad || c->grad)) {
-        is_node = true;
-    }
-
+        bool                           inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
-    result->op = GGML_OP_MAP_CUSTOM3_F32;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_CUSTOM3_F32;
     result->src[0] = a;
     result->src[1] = b;
     result->src[2] = c;
@@ -8201,26 +7616,20 @@ struct ggml_tensor * ggml_map_custom3_inplace_f32(
 
 // ggml_map_custom1
 struct ggml_map_custom1_op_params {
-    ggml_custom1_op_t fun;
-    int n_tasks;
-    void * userdata;
+    ggml_custom1_op_t  fun;
+    int                n_tasks;
+    void             * userdata;
 };
 
 static struct ggml_tensor * ggml_map_custom1_impl(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata,
-        bool                           inplace) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        const  ggml_custom1_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata,
+        bool                       inplace) {
     GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
 
-    bool is_node = false;
-
-    if (!inplace && a->grad) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     struct ggml_map_custom1_op_params params = {
@@ -8230,55 +7639,48 @@ static struct ggml_tensor * ggml_map_custom1_impl(
     };
     ggml_set_op_params(result, (const void *) &params, sizeof(params));
 
-    result->op = GGML_OP_MAP_CUSTOM1;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
 
     return result;
 }
 
 struct ggml_tensor * ggml_map_custom1(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        const  ggml_custom1_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
     return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
 }
 
 struct ggml_tensor * ggml_map_custom1_inplace(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        const  ggml_custom1_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        const  ggml_custom1_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
     return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
 }
 
 // ggml_map_custom2
 
 struct ggml_map_custom2_op_params {
-    ggml_custom2_op_t fun;
-    int n_tasks;
-    void * userdata;
+    ggml_custom2_op_t   fun;
+    int                 n_tasks;
+    void              * userdata;
 };
 
 static struct ggml_tensor * ggml_map_custom2_impl(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata,
-        bool                           inplace) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        const  ggml_custom2_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata,
+        bool                       inplace) {
     GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad)) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     struct ggml_map_custom2_op_params params = {
@@ -8288,8 +7690,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
     };
     ggml_set_op_params(result, (const void *) &params, sizeof(params));
 
-    result->op = GGML_OP_MAP_CUSTOM2;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -8297,22 +7698,22 @@ static struct ggml_tensor * ggml_map_custom2_impl(
 }
 
 struct ggml_tensor * ggml_map_custom2(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        const  ggml_custom2_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
     return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
 }
 
 struct ggml_tensor * ggml_map_custom2_inplace(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        const  ggml_custom2_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        const  ggml_custom2_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
     return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
 }
 
@@ -8325,22 +7726,16 @@ struct ggml_map_custom3_op_params {
 };
 
 static struct ggml_tensor * ggml_map_custom3_impl(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata,
-        bool                           inplace) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        struct ggml_tensor       * c,
+        const  ggml_custom3_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata,
+        bool                       inplace) {
     GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
 
-    bool is_node = false;
-
-    if (!inplace && (a->grad || b->grad || c->grad)) {
-        is_node = true;
-    }
-
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     struct ggml_map_custom3_op_params params = {
@@ -8350,8 +7745,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
     };
     ggml_set_op_params(result, (const void *) &params, sizeof(params));
 
-    result->op = GGML_OP_MAP_CUSTOM3;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
     result->src[1] = b;
     result->src[2] = c;
@@ -8360,44 +7754,38 @@ static struct ggml_tensor * ggml_map_custom3_impl(
 }
 
 struct ggml_tensor * ggml_map_custom3(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        struct ggml_tensor       * c,
+        const  ggml_custom3_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
 }
 
 struct ggml_tensor * ggml_map_custom3_inplace(
-        struct ggml_context          * ctx,
-        struct ggml_tensor           * a,
-        struct ggml_tensor           * b,
-        struct ggml_tensor           * c,
-        const  ggml_custom3_op_t       fun,
-        int                            n_tasks,
-        void                         * userdata) {
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        struct ggml_tensor       * c,
+        const  ggml_custom3_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
     GGML_ASSERT(ggml_are_same_shape(a, b));
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        is_node = true;
-    }
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
 
-    result->op   = GGML_OP_CROSS_ENTROPY_LOSS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_CROSS_ENTROPY_LOSS;
     result->src[0] = a;
     result->src[1] = b;
 
@@ -8407,17 +7795,16 @@ struct ggml_tensor * ggml_cross_entropy_loss(
 // ggml_cross_entropy_loss_back
 
 struct ggml_tensor * ggml_cross_entropy_loss_back(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        struct ggml_tensor          * c) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
     GGML_ASSERT(ggml_are_same_shape(a, b));
     GGML_ASSERT(ggml_is_scalar(c));
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    result->op   = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
-    result->grad = NULL;
+    result->op     = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
     result->src[0] = a;
     result->src[1] = b;
     result->src[2] = c;
@@ -8435,7 +7822,7 @@ struct ggml_tensor * ggml_opt_step_adamw(
         float                 beta2,
         float                 eps,
         float                 wd) {
-    GGML_ASSERT(a->grad);
+    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
     GGML_ASSERT(alpha >  0.0f);
     GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
     GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -8444,13 +7831,6 @@ struct ggml_tensor * ggml_opt_step_adamw(
 
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
-    result->op   = GGML_OP_OPT_STEP_ADAMW;
-    result->grad = NULL;
-    result->src[0] = a;
-    result->src[1] = a->grad;
-    result->src[2] = ggml_dup_tensor(ctx, a->grad);
-    result->src[3] = ggml_dup_tensor(ctx, a->grad);
-
     const int64_t iter = 1;
     memcpy(&result->op_params[0], &iter, sizeof(int64_t));
     ggml_set_op_params_f32(result, 2, alpha);
@@ -8459,26 +7839,17 @@ struct ggml_tensor * ggml_opt_step_adamw(
     ggml_set_op_params_f32(result, 5, eps);
     ggml_set_op_params_f32(result, 6, wd);
 
+    result->op     = GGML_OP_OPT_STEP_ADAMW;
+    result->src[0] = a;
+    result->src[1] = a->grad;
+    result->src[2] = ggml_dup_tensor(ctx, a);
+    result->src[3] = ggml_dup_tensor(ctx, a);
+
     return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
-void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
-    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
-
-    GGML_ASSERT(tensor->grad == NULL);
-    tensor->grad = ggml_dup_tensor(ctx, tensor);
-    ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
-}
-
-void ggml_set_loss(struct ggml_tensor * tensor) {
-    GGML_ASSERT(ggml_is_scalar(tensor));
-    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
-    GGML_ASSERT(tensor->grad);
-    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
-}
-
 // ggml_compute_forward_dup
 
 static void ggml_compute_forward_dup_same_cont(
@@ -18198,7 +17569,7 @@ void ggml_build_backward_gradient_checkpointing(
         struct ggml_tensor  * * checkpoints,
         int                     n_checkpoints) {
     ggml_graph_cpy(gf, gb_tmp);
-    ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
+    ggml_build_backward_expand(ctx, gf, gb_tmp, false);
 
     if (n_checkpoints <= 0) {
         ggml_graph_cpy(gb_tmp, gb);
@@ -18850,7 +18221,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_soft_max_back(ctx, tensor->grad, tensor),
                         zero_table, acc_table);
                 }
-
+                GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented");
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -18891,6 +18262,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 beta_slow),
                             zero_table, acc_table);
                 }
+                GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented");
             } break;
         case GGML_OP_ROPE_BACK:
             {
@@ -19012,6 +18384,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             }
         case GGML_OP_FLASH_ATTN_EXT:
             {
+                GGML_ABORT("FA backward pass not adapted after rework");
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
                     int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -19186,6 +18559,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     tensor->grad),
                                 zero_table, acc_table);
                 }
+                GGML_ASSERT(!src1->grad && "backward pass for labels not implemented");
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
@@ -19236,7 +18610,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         }
     }
 
-    if (node->op == GGML_OP_NONE && node->grad == NULL) {
+    if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) {
         // reached a leaf node, not part of the gradient graph (e.g. a constant)
         GGML_ASSERT(cgraph->n_leafs < cgraph->size);
 
@@ -19254,9 +18628,6 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         }
 
         cgraph->nodes[cgraph->n_nodes] = node;
-        if (cgraph->grads) {
-            cgraph->grads[cgraph->n_nodes] = node->grad;
-        }
         cgraph->n_nodes++;
     }
 }
@@ -19284,20 +18655,58 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
     ggml_build_forward_impl(cgraph, tensor, true);
 }
 
-void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate) {
     GGML_ASSERT(gf->n_nodes > 0);
     GGML_ASSERT(gf->grads);
 
-    // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
-    if (keep) {
-        for (int i = 0; i < gf->n_nodes; i++) {
-            struct ggml_tensor * node = gf->nodes[i];
+    for (int i = 0; i < gf->n_nodes; ++i) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
+        bool ignore_src[GGML_MAX_SRC] = {false};
+        switch (node->op) {
+            // gradients in node->src[0] for one reason or another have no effect on output gradients
+            case GGML_OP_IM2COL:      // only used for its shape
+            case GGML_OP_IM2COL_BACK: // same as IM2COL
+                ignore_src[0] = true;
+                break;
+            case GGML_OP_UNARY: {
+                const enum ggml_unary_op uop = ggml_get_unary_op(node);
+                // SGN and STEP unary ops are piecewise constant
+                if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) {
+                    ignore_src[0] = true;
+                }
+            } break;
+
+            // gradients in node->src[1] for one reason or another have no effect on output gradients
+            case GGML_OP_CPY:           // gradients in CPY target  are irrelevant
+            case GGML_OP_GET_ROWS:      // row indices not differentiable
+            case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
+            case GGML_OP_ROPE:          // positions not differentiable
+                ignore_src[1] = true;
+                break;
 
-            if (node->grad) {
-                node->grad = ggml_dup_tensor(ctx, node);
-                gf->grads[i] = node->grad;
+            default:
+                break;
+        }
+        for (int j = 0; j < GGML_MAX_SRC; ++j) {
+            if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) {
+                continue;
             }
+            GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
+            needs_grad = true;
+            break;
+        }
+        if (!needs_grad) {
+            continue;
         }
+
+        // inplace operations are currently not supported
+        GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
+            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
+
+        // create a new tensor with the same type and shape as the node and set it as grad
+        node->grad = ggml_dup_tensor(ctx, node);
     }
 
     // keep tables of original gradients for replacement/accumulation logic
@@ -22162,8 +21571,6 @@ enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
         struct ggml_tensor * f) {
-    GGML_ASSERT(f->grad && "ggml_set_param called for at least one parent tensor.");
-
     bool free_ctx = false;
     if (ctx == NULL) {
         struct ggml_init_params params_ctx = {
@@ -22204,7 +21611,7 @@ enum ggml_opt_result ggml_opt_resume(
     ggml_build_forward_expand(gf, f);
 
     struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
-    ggml_build_backward_expand(ctx, gf, gb, false, true);
+    ggml_build_backward_expand(ctx, gf, gb, false);
 
     return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }
@@ -22257,6 +21664,17 @@ void ggml_set_output(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
 }
 
+void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
+    GGML_UNUSED(ctx); // TODO: remove this parameter
+    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
+}
+
+void ggml_set_loss(struct ggml_tensor * tensor) {
+    GGML_ASSERT(ggml_is_scalar(tensor));
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_quantize_init(enum ggml_type type) {
index d2cfe06b592cfc25a2753bc63462c1947d2ca092..5c78b6704da57d06896499d2f4f5a00a993fee8a 100644 (file)
@@ -1,6 +1,6 @@
 // This file defines tests for various GGML ops and backends.
 // For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
-// For the backwards pass it asserts that the gradients from backpropagation are consistent
+// For the backward pass it asserts that the gradients from backpropagation are consistent
 // with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
 // It is also possible to check the performance ("perf" mode).
 //
@@ -740,7 +740,7 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && op_desc(out) != op_name) {
+        if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
@@ -749,11 +749,6 @@ struct test_case {
         printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
-        if (out->grad == nullptr) {
-            printf("backwards pass not supported \n");
-            ggml_free(ctx);
-            return true;
-        }
         if (out->type != GGML_TYPE_F32) {
             ggml_free(ctx);
             printf("not supported [%s->type != FP32]\n", out->name);
@@ -762,18 +757,26 @@ struct test_case {
 
         // check if the backend supports the ops
         bool supported = true;
+        bool any_params = false;
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));
                 supported = false;
                 break;
             }
-            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
-                printf("not supported [%s->type != FP32] ", t->name);
-                supported = false;
-                break;
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
+                any_params = true;
+                if (t->type != GGML_TYPE_F32) {
+                    printf("not supported [%s->type != FP32] ", t->name);
+                    supported = false;
+                    break;
+                }
             }
         }
+        if (!any_params) {
+            printf("not supported [%s] \n", op_name);
+            supported = false;
+        }
         if (!supported) {
             printf("\n");
             ggml_free(ctx);
@@ -801,7 +804,7 @@ struct test_case {
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx, gf, gb, false, false);
+        ggml_build_backward_expand(ctx, gf, gb, false);
         if (expect.size() != 1 || expect[0] != 0.0f) {
             GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
             for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
@@ -984,7 +987,7 @@ struct test_example : public test_case {
     }
     // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
     // immediately after you create the tensors.
-    // This is optional and only makes sense if a backwards pass has actually been implemented for the new op.
+    // This is optional and only makes sense if a backward pass has actually been implemented for the new op.
 };
 
 
@@ -1223,7 +1226,7 @@ struct test_set : public test_case {
             offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
         }
         ggml_tensor * out = ggml_set(ctx, dst, src,
-            // The backwards pass requires setting a contiguous region:
+            // The backward pass requires setting a contiguous region:
             src->nb[1], src->nb[2], src->nb[3], offset);
         ggml_set_name(out, "out");
 
@@ -1335,7 +1338,7 @@ struct test_bin_bcast : public test_case {
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(b, "b");
 
-        // The backwards pass supports broadcasting only for GGML_ADD:
+        // The backward pass supports broadcasting only for GGML_ADD:
         const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
         if (grad_supported) {
             ggml_set_param(ctx, a);
@@ -1830,7 +1833,7 @@ struct test_log : public test_case {
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            // log(1) == 0, cluster values there to keep the sum low for better precision in the backwards pass:
+            // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass:
             init_tensor_uniform(t, 0.9f, 1.1f);
         }
     }
@@ -3257,7 +3260,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    for (int ne3 : {1, 3}) { // CUDA backwards pass only supports ne3 == 1
+    for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
index 2ef606d2c3591cbdd85d3ee120430e740875b2fc..2200ad93dbfc5f466281ddda1c09816289dc2fcd 100644 (file)
@@ -240,12 +240,14 @@ static bool check_gradient(
     struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
     ggml_build_forward_expand(gf, f);
     ggml_graph_cpy(gf, gb);
-    ggml_build_backward_expand(ctx0, gf, gb, false, false);
+    ggml_build_backward_expand(ctx0, gf, gb, false);
 
     ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
-    ggml_graph_reset  (gf);
-    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_reset(gb);
+    if (f->grad) {
+        ggml_set_f32(f->grad, 1.0f);
+    }
 
     ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
@@ -298,8 +300,10 @@ static bool check_gradient(
             ggml_set_f32_1d(x[i], k, x0);
 
             // compute gradient using backward graph
-            ggml_graph_reset  (gf);
-            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_reset(gb);
+            if (f->grad) {
+                ggml_set_f32(f->grad, 1.0f);
+            }
 
             ggml_graph_compute_with_ctx(ctx0, gb, n_threads);