// Builds the GGML_OP_OPT_STEP_ADAMW node for the parameter tensor `a`.
//
// This excerpt is in patch form: lines starting with `-` are the old code and
// lines starting with `+` are their replacement. The change turns the gradient
// into an explicit argument (`grad`) instead of reading it from `a->grad`.
//
// Parameters:
//   ctx          - context handed to ggml_dup_tensor for the two extra sources
//   a            - tensor to be optimized; must carry GGML_TENSOR_FLAG_PARAM
//   grad         - gradient paired with `a`; must have the same shape as `a`
//   alpha        - must be > 0 (presumably the learning rate - TODO confirm)
//   beta1, beta2 - must each lie in [0, 1] (presumably the moment decay rates - TODO confirm)
//   eps, wd      - not validated in the visible lines (presumably epsilon and
//                  weight decay - TODO confirm)
//
// Returns `result` with its op set to GGML_OP_OPT_STEP_ADAMW and src[0..3] filled in.
//
// NOTE(review): the declaration of `result`, and whatever stores the float
// hyperparameters on it, are not visible in this excerpt.
struct ggml_tensor * ggml_opt_step_adamw(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ struct ggml_tensor * grad,
float alpha,
float beta1,
float beta2,
float eps,
float wd) {
// Only tensors flagged as parameters may be given an optimizer step.
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
// New with the explicit argument: the gradient must match the parameter's shape.
+ GGML_ASSERT(ggml_are_same_shape(a, grad));
// Hyperparameter range checks; eps and wd are not checked here.
GGML_ASSERT(alpha > 0.0f);
GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
result->op = GGML_OP_OPT_STEP_ADAMW;
// src[0] is the parameter tensor itself.
result->src[0] = a;
// Old wiring: gradient taken from a->grad, extra sources duplicated from `a`.
- result->src[1] = a->grad;
- result->src[2] = ggml_dup_tensor(ctx, a);
- result->src[3] = ggml_dup_tensor(ctx, a);
// New wiring: gradient taken from the argument, extra sources duplicated from `grad`.
// NOTE(review): src[2] and src[3] are presumably the optimizer's first and second
// moment buffers - confirm against the op's compute implementation.
+ result->src[1] = grad;
+ result->src[2] = ggml_dup_tensor(ctx, grad);
+ result->src[3] = ggml_dup_tensor(ctx, grad);
return result;
}
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
- struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
ggml_build_forward_expand(gb, opt_step);
}
}