]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
test: fix OPT_STEP_ADAMW for test-backend-ops (#974)
authorJohannes Gäßler <redacted>
Mon, 30 Sep 2024 07:55:23 +0000 (09:55 +0200)
committerGitHub <redacted>
Mon, 30 Sep 2024 07:55:23 +0000 (09:55 +0200)
include/ggml.h
src/ggml.c
tests/test-backend-ops.cpp

index a8a74bee13de33b066547bace9373e07332b7f87..ce3d92cb2e0f060dc9dfc39ea3e16b0943a7fb67 100644 (file)
@@ -2052,6 +2052,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_opt_step_adamw(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
+            struct ggml_tensor  * grad,
             float                 alpha,
             float                 beta1,
             float                 beta2,
index aac4e3a7b0c676a84fed4b06d22a33618e386441..bcbc32d913dec87393c8db7b667802637cb11b8c 100644 (file)
@@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
 struct ggml_tensor * ggml_opt_step_adamw(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
         float                 alpha,
         float                 beta1,
         float                 beta2,
         float                 eps,
         float                 wd) {
     GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
     GGML_ASSERT(alpha >  0.0f);
     GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
     GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
 
     result->op     = GGML_OP_OPT_STEP_ADAMW;
     result->src[0] = a;
-    result->src[1] = a->grad;
-    result->src[2] = ggml_dup_tensor(ctx, a);
-    result->src[3] = ggml_dup_tensor(ctx, a);
+    result->src[1] = grad;
+    result->src[2] = ggml_dup_tensor(ctx, grad);
+    result->src[3] = ggml_dup_tensor(ctx, grad);
 
     return result;
 }
@@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
 
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
             ggml_build_forward_expand(gb, opt_step);
         }
     }
index 5c78b6704da57d06896499d2f4f5a00a993fee8a..95d983aa083c3811772949805d0bf7d5a08c36ce 100644 (file)
@@ -2751,7 +2751,10 @@ struct test_opt_step_adamw : public test_case {
         ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd);
+        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd);
         ggml_set_name(out, "out");
 
         return out;