From: Johannes Gäßler Date: Mon, 19 May 2025 07:33:35 +0000 (+0200) Subject: mnist: fix segmentation fault (#1227) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=7c06c10c532a6cda913c17fc56341e8880ae341d;p=pkg%2Fggml%2Fsources%2Fggml mnist: fix segmentation fault (#1227) --- diff --git a/examples/mnist/mnist-common.cpp b/examples/mnist/mnist-common.cpp index 6bf369f5..a303bcec 100644 --- a/examples/mnist/mnist-common.cpp +++ b/examples/mnist/mnist-common.cpp @@ -385,7 +385,10 @@ ggml_opt_result_t mnist_model_eval(mnist_model & model, ggml_opt_dataset_t datas ggml_opt_result_t result = ggml_opt_result_init(); ggml_opt_params params = ggml_opt_default_params(model.backend_sched, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); - params.build_type = GGML_OPT_BUILD_TYPE_FORWARD; + params.ctx_compute = model.ctx_compute; + params.inputs = model.images; + params.outputs = model.logits; + params.build_type = GGML_OPT_BUILD_TYPE_FORWARD; ggml_opt_context_t opt_ctx = ggml_opt_init(params); { diff --git a/include/ggml-opt.h b/include/ggml-opt.h index da0c24b4..74ec080a 100644 --- a/include/ggml-opt.h +++ b/include/ggml-opt.h @@ -128,6 +128,8 @@ extern "C" { // set gradients to zero, initilize loss, and optionally reset the optimizer GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically + // get underlying tensors that store data // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor diff --git a/src/ggml-opt.cpp b/src/ggml-opt.cpp index 58d77578..a3c82d67 100644 --- a/src/ggml-opt.cpp +++ b/src/ggml-opt.cpp @@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { } } +bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) { + return opt_ctx->static_graphs; +} + struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { return opt_ctx->inputs; } @@ -842,6 +846,7 @@ void ggml_opt_epoch( int64_t idata_split, ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval) { + GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs"); struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); struct ggml_tensor * data = ggml_opt_dataset_data(dataset);