From: Johannes Gäßler Date: Sat, 5 Oct 2024 16:38:01 +0000 (+0200) Subject: examples: add dataset, data shuffling to MNIST (#982) X-Git-Tag: upstream/0.0.1642~302 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=130bc179471fb5d69edd313718e9e70e84ca67d9;p=pkg%2Fggml%2Fsources%2Fggml examples: add dataset, data shuffling to MNIST (#982) --- diff --git a/examples/mnist/mnist-common.cpp b/examples/mnist/mnist-common.cpp index c70dfd7f..0c4ca07d 100644 --- a/examples/mnist/mnist-common.cpp +++ b/examples/mnist/mnist-common.cpp @@ -14,7 +14,7 @@ #include #include -bool mnist_image_load(const std::string & fname, float * buf, const int nex) { +bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) { auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { fprintf(stderr, "failed to open images file %s\n", fname.c_str()); @@ -23,8 +23,9 @@ bool mnist_image_load(const std::string & fname, float * buf, const int nex) { fin.seekg(16); uint8_t image[MNIST_NINPUT]; + float * buf = ggml_get_data_f32(dataset.data); - for (int iex = 0; iex < nex; ++iex) { + for (int iex = 0; iex < dataset.nex; ++iex) { fin.read((char *) image, sizeof(image)); for (int i = 0; i < MNIST_NINPUT; ++i) { @@ -35,12 +36,12 @@ bool mnist_image_load(const std::string & fname, float * buf, const int nex) { return true; } -void mnist_image_print(FILE * stream, const float * image) { - static_assert(MNIST_NINPUT == 28*28, "Unexpected MNIST_NINPUT"); +void mnist_image_print(FILE * stream, mnist_dataset & dataset, const int iex) { + const float * image = ggml_get_data_f32(dataset.data) + iex*MNIST_NINPUT; - for (int row = 0; row < 28; row++) { - for (int col = 0; col < 28; col++) { - const int rgb = roundf(255.0f * image[row*28 + col]); + for (int row = 0; row < MNIST_HW; row++) { + for (int col = 0; col < MNIST_HW; col++) { + const int rgb = roundf(255.0f * image[row*MNIST_HW + col]); #ifdef _WIN32 fprintf(stream, "%s", rgb >= 220 ? "##" : "__"); // Represented via text. #else @@ -51,7 +52,7 @@ void mnist_image_print(FILE * stream, const float * image) { } } -bool mnist_label_load(const std::string & fname, float * buf, const int nex) { +bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) { auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { fprintf(stderr, "failed to open labels file %s\n", fname.c_str()); @@ -60,8 +61,9 @@ bool mnist_label_load(const std::string & fname, float * buf, const int nex) { fin.seekg(8); uint8_t label; + float * buf = ggml_get_data_f32(dataset.labels); - for (int iex = 0; iex < nex; ++iex) { + for (int iex = 0; iex < dataset.nex; ++iex) { fin.read((char *) &label, sizeof(label)); for (int i = 0; i < MNIST_NCLASSES; ++i) { @@ -500,7 +502,7 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int GGML_ASSERT(model.acc_count->ne[3] == 1); } -mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex) { +mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset) { mnist_eval_result result; struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute); @@ -522,10 +524,10 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred)); GGML_ASSERT(sizeof(tmp_acc_count) == ggml_nbytes(model.acc_count)); - GGML_ASSERT(nex % model.nbatch_physical == 0); - for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch_physical) { - ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images)); - ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels)); + GGML_ASSERT(dataset.nex % model.nbatch_physical == 0); + const int nbatches = dataset.nex/model.nbatch_physical; + for (int ibatch = 0; ibatch < nbatches; ++ibatch) { + dataset.get_batch(model.images, model.labels, ibatch); ggml_backend_graph_compute(model.backend, gf); @@ -542,17 +544,21 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co const int64_t t_total_us = ggml_time_us() - t_start_us; const double t_total_ms = 1e-3*t_total_us; fprintf(stderr, "%s: model evaluation on %d images took %.2lf ms, %.2lf us/image\n", - __func__, nex, t_total_ms, (double) t_total_us/nex); + __func__, (int)dataset.nex, t_total_ms, (double) t_total_us/dataset.nex); } result.success = true; return result; } -void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split) { +void mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split) { const int64_t t_start_us = ggml_time_us(); - const bool accumulate = model.nbatch_physical != model.nbatch_logical; + const int opt_period = model.nbatch_logical / model.nbatch_physical; + const int nbatches_logical = dataset.nex / model.nbatch_logical; + const int nbatches_physical = dataset.nex / model.nbatch_physical; + const int ibatch_split = ((int)((1.0f - val_split)*nbatches_logical))*opt_period; // train <-> val split index (physical) + const int ishard_split = ibatch_split * model.nbatch_physical/dataset.shard_size; // gf == graph forward, forward pass only. struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. @@ -563,7 +569,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float * // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf); - ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, accumulate); + ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ opt_period > 1); // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad); @@ -572,13 +578,15 @@ void mnist_model_train(mnist_model & model, const float * images, const float * model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend); ggml_graph_reset(gb_opt); // Set gradients to zero, reset optimizer. - const int iex_split = ((int)((1.0f - val_split)*nex) / model.nbatch_logical) * model.nbatch_logical; + dataset.shuffle(-1); // Shuffle all data (train + validation). for (int epoch = 0; epoch < nepoch; ++epoch) { fprintf(stderr, "%s: epoch %02d start...", __func__, epoch); const int64_t t_start_us = ggml_time_us(); - int iex0 = 0; + dataset.shuffle(ishard_split); // Shuffle only the training data, keeping training and validation set separate. + + int ibatch_physical = 0; float tmp_loss; std::vector tmp_pred(model.nbatch_physical); @@ -589,13 +597,12 @@ void mnist_model_train(mnist_model & model, const float * images, const float * GGML_ASSERT(sizeof(tmp_acc_count) == ggml_nbytes(model.acc_count)); mnist_eval_result result_train; - for (; iex0 < iex_split; iex0 += model.nbatch_physical) { - ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images)); - ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels)); + for (; ibatch_physical < ibatch_split; ++ibatch_physical) { + dataset.get_batch(model.images, model.labels, ibatch_physical); - // With a period of nbatch_logical/nbatch_physical iterations: - if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) { - // For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them: + // With a period of opt_period == nbatch_logical/nbatch_physical iterations: + if ((ibatch_physical + 1) % opt_period != 0) { + // For the first opt_period - 1 iterations, only calculate gradients and accumulate them: ggml_backend_graph_compute(model.backend, gb_grad); } else { // For the last iteration, calculate gradients and also apply the optimizer: @@ -614,9 +621,8 @@ void mnist_model_train(mnist_model & model, const float * images, const float * } mnist_eval_result result_val; - for (; iex0 < nex; iex0 += model.nbatch_physical) { - ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images)); - ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels)); + for (; ibatch_physical < nbatches_physical; ++ibatch_physical) { + dataset.get_batch(model.images, model.labels, ibatch_physical); ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed. @@ -639,7 +645,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float * fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%", t_epoch_s, loss_mean, percent_correct); } - if (iex_split < nex) { + if (ibatch_split < nbatches_physical) { const std::pair loss = mnist_loss(result_val); const std::pair acc = mnist_accuracy(result_val); @@ -731,11 +737,14 @@ extern "C" { int wasm_eval(uint8_t * digitPtr) { std::vector digit(digitPtr, digitPtr + MNIST_NINPUT); - std::vector labels(MNIST_NCLASSES); + + struct mnist_dataset dataset(1, 1); + memcpy(dataset.data->data, digitPtr, ggml_nbytes(dataset.data)); + ggml_set_zero(dataset.labels); // The labels are not needed. mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU"); mnist_model_build(model, 1, 1); - mnist_eval_result result = mnist_model_eval(model, digit.data(), labels.data(), 1); + mnist_eval_result result = mnist_model_eval(model, dataset); return result.pred[0]; } diff --git a/examples/mnist/mnist-common.h b/examples/mnist/mnist-common.h index 1c6fe783..f0d24804 100644 --- a/examples/mnist/mnist-common.h +++ b/examples/mnist/mnist-common.h @@ -1,4 +1,6 @@ +#include #include +#include #include #include #include @@ -25,6 +27,64 @@ static_assert(MNIST_NTEST % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NB // NCB = number of channels base #define MNIST_CNN_NCB 8 +struct mnist_dataset { + struct ggml_context * ctx; + struct ggml_tensor * data; + struct ggml_tensor * labels; + + int64_t nex; + int64_t shard_size; + size_t nbs_data; + size_t nbs_labels; + + std::vector permutation; + std::mt19937 rng; + + mnist_dataset(const int64_t nex, const int64_t shard_size) : nex(nex), shard_size(shard_size) { + const size_t nbytes_images = nex*MNIST_NINPUT *sizeof(float) + ggml_tensor_overhead(); + const size_t nbytes_labels = nex*MNIST_NCLASSES*sizeof(float) + ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ nbytes_images + nbytes_labels, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ false, + }; + ctx = ggml_init(params); + + data = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, MNIST_HW, MNIST_HW, nex); + labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, MNIST_NCLASSES, nex); + + nbs_data = ggml_nbytes(data) * shard_size/nex; + nbs_labels = ggml_nbytes(labels) * shard_size/nex; + + permutation.resize(nex/shard_size); + for (size_t i = 0; i < permutation.size(); ++i) { + permutation[i] = i; + } + } + + ~mnist_dataset() { + ggml_free(ctx); + } + + void shuffle(const size_t ishard_max) { + if (ishard_max < permutation.size()) { + std::shuffle(permutation.begin(), permutation.begin() + ishard_max, rng); + return; + } + std::shuffle(permutation.begin(), permutation.end(), rng); + } + + void get_batch(struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, const int64_t ibatch) { + const int64_t shards_per_batch = ggml_nbytes(data_batch) / nbs_data; + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = permutation[ibatch*shards_per_batch + ishard_batch]; + + ggml_backend_tensor_set(data_batch, (const char *) data->data + ishard*nbs_data, ishard_batch*nbs_data, nbs_data); + ggml_backend_tensor_set(labels_batch, (const char *) labels->data + ishard*nbs_labels, ishard_batch*nbs_labels, nbs_labels); + } + } +}; + struct mnist_model { std::string arch; ggml_backend_t backend; @@ -116,17 +176,17 @@ struct mnist_eval_result { int64_t ntotal = 0; }; -bool mnist_image_load(const std::string & fname, float * buf, const int nex); -void mnist_image_print(FILE * f, const float * image); -bool mnist_label_load(const std::string & fname, float * buf, const int nex); +bool mnist_image_load(const std::string & fname, mnist_dataset & dataset); +void mnist_image_print(FILE * f, mnist_dataset & dataset, const int iex); +bool mnist_label_load(const std::string & fname, mnist_dataset & dataset); mnist_eval_result mnist_graph_eval(const std::string & fname, const float * images, const float * labels, const int nex, const int nthreads); mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend); mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend); void mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical); -mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex); -void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split); +mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset); +void mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split); void mnist_model_save(mnist_model & model, const std::string & fname); std::pair mnist_loss(const mnist_eval_result & result); diff --git a/examples/mnist/mnist-eval.cpp b/examples/mnist/mnist-eval.cpp index a5f9676f..125bbb8c 100644 --- a/examples/mnist/mnist-eval.cpp +++ b/examples/mnist/mnist-eval.cpp @@ -24,22 +24,17 @@ int main(int argc, char ** argv) { exit(1); } - std::vector images; - images.resize(MNIST_NTEST*MNIST_NINPUT); - if (!mnist_image_load(argv[2], images.data(), MNIST_NTEST)) { + struct mnist_dataset dataset(/*nex =*/ MNIST_NTEST, /*shard_size =*/ MNIST_NBATCH_PHYSICAL); + + if (!mnist_image_load(argv[2], dataset)) { return 1; } - - std::vector labels; - labels.resize(MNIST_NTEST*MNIST_NCLASSES); - if (!mnist_label_load(argv[3], labels.data(), MNIST_NTEST)) { + if (!mnist_label_load(argv[3], dataset)) { return 1; } const int iex = rand() % MNIST_NTEST; - const std::vector digit(images.begin() + iex*MNIST_NINPUT, images.begin() + (iex+1)*MNIST_NINPUT); - - mnist_image_print(stdout, images.data() + iex*MNIST_NINPUT); + mnist_image_print(stdout, dataset, iex); const std::string backend = argc >= 5 ? argv[4] : "CPU"; @@ -48,7 +43,7 @@ int main(int argc, char ** argv) { if (backend == "CPU") { const int ncores_logical = std::thread::hardware_concurrency(); result_eval = mnist_graph_eval( - argv[1], images.data(), labels.data(), MNIST_NTEST, std::min(ncores_logical, (ncores_logical + 4)/2)); + argv[1], ggml_get_data_f32(dataset.data), ggml_get_data_f32(dataset.labels), MNIST_NTEST, std::min(ncores_logical, (ncores_logical + 4)/2)); if (result_eval.success) { fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]); @@ -73,7 +68,7 @@ int main(int argc, char ** argv) { const int64_t t_load_us = ggml_time_us() - t_start_us; fprintf(stdout, "%s: loaded model in %.2lf ms\n", __func__, t_load_us / 1000.0); - result_eval = mnist_model_eval(model, images.data(), labels.data(), MNIST_NTEST); + result_eval = mnist_model_eval(model, dataset); fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]); std::pair result_loss = mnist_loss(result_eval); diff --git a/examples/mnist/mnist-train.cpp b/examples/mnist/mnist-train.cpp index 41a16221..161dcf80 100644 --- a/examples/mnist/mnist-train.cpp +++ b/examples/mnist/mnist-train.cpp @@ -5,7 +5,6 @@ #include #include #include -#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -17,15 +16,15 @@ int main(int argc, char ** argv) { exit(0); } - std::vector images; - images.resize(MNIST_NTRAIN*MNIST_NINPUT); - if (!mnist_image_load(argv[3], images.data(), MNIST_NTRAIN)) { + // The MNIST model is so small that the overhead from data shuffling is non-negligible, especially with CUDA. + // With a shard size of 10 this overhead is greatly reduced at the cost of less shuffling (does not seem to have a significant impact). + // A batch of 500 images then consists of 50 random shards of size 10 instead of 500 random shards of size 1. + struct mnist_dataset dataset(/*nex =*/ MNIST_NTRAIN, /*shard_size =*/ 10); + + if (!mnist_image_load(argv[3], dataset)) { return 1; } - - std::vector labels; - labels.resize(MNIST_NTRAIN*MNIST_NCLASSES); - if (!mnist_label_load(argv[4], labels.data(), MNIST_NTRAIN)) { + if (!mnist_label_load(argv[4], dataset)) { return 1; } @@ -33,7 +32,7 @@ int main(int argc, char ** argv) { mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL); - mnist_model_train(model, images.data(), labels.data(), MNIST_NTRAIN, /*nepoch =*/ 30, /*val_split =*/ 0.05f); + mnist_model_train(model, dataset, /*nepoch =*/ 30, /*val_split =*/ 0.05f); mnist_model_save(model, argv[2]); }