]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mnist: fix segmentation fault (ggml/1227)
authorJohannes Gäßler <redacted>
Mon, 19 May 2025 07:33:35 +0000 (09:33 +0200)
committerGeorgi Gerganov <redacted>
Mon, 19 May 2025 10:29:56 +0000 (13:29 +0300)
ggml/include/ggml-opt.h
ggml/src/ggml-opt.cpp

index da0c24b46fed96739dab53670ed72b243ca7dd65..74ec080a055eaeb4134080c6a77f25feffc0353e 100644 (file)
@@ -128,6 +128,8 @@ extern "C" {
     // set gradients to zero, initilize loss, and optionally reset the optimizer
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
 
+    GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
+
     // get underlying tensors that store data
     // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
     GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
index 58d77578f458dafa9c97a6099f0618f716c7f601..a3c82d6757714b64f4a46dfb0e5a7a62c22ee566 100644 (file)
@@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
     }
 }
 
+bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
+    return opt_ctx->static_graphs;
+}
+
 struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
     return opt_ctx->inputs;
 }
@@ -842,6 +846,7 @@ void ggml_opt_epoch(
         int64_t                 idata_split,
         ggml_opt_epoch_callback callback_train,
         ggml_opt_epoch_callback callback_eval) {
+    GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
     struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
     struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
     struct ggml_tensor * data   = ggml_opt_dataset_data(dataset);