// set gradients to zero, initilize loss, and optionally reset the optimizer
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
+ GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
+
// get underlying tensors that store data
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
}
}
+bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
+ return opt_ctx->static_graphs;
+}
+
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
return opt_ctx->inputs;
}
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) {
+ GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);