void mnist_model_train(mnist_model & model, ggml_opt_dataset_t dataset, const int nepoch, const float val_split) {
ggml_opt_fit(model.backend_sched, model.ctx_compute, model.images, model.logits, dataset,
- GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, ggml_opt_get_default_optimizer_params, nepoch, model.nbatch_logical, val_split, false);
+ GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, GGML_OPT_OPTIMIZER_TYPE_ADAMW, ggml_opt_get_default_optimizer_params, nepoch, model.nbatch_logical, val_split, false);
}
void mnist_model_save(mnist_model & model, const std::string & fname) {