]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
examples: add dataset, data shuffling to MNIST (#982)
authorJohannes Gäßler <redacted>
Sat, 5 Oct 2024 16:38:01 +0000 (18:38 +0200)
committerGitHub <redacted>
Sat, 5 Oct 2024 16:38:01 +0000 (18:38 +0200)
examples/mnist/mnist-common.cpp
examples/mnist/mnist-common.h
examples/mnist/mnist-eval.cpp
examples/mnist/mnist-train.cpp

index c70dfd7fdc0e57d1660ed2d31ae63aa1ab45b70b..0c4ca07d01297d654aafd7cf98c9fcf60ae724ca 100644 (file)
@@ -14,7 +14,7 @@
 #include <string>
 #include <utility>
 
-bool mnist_image_load(const std::string & fname, float * buf, const int nex) {
+bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) {
     auto fin = std::ifstream(fname, std::ios::binary);
     if (!fin) {
         fprintf(stderr, "failed to open images file %s\n", fname.c_str());
@@ -23,8 +23,9 @@ bool mnist_image_load(const std::string & fname, float * buf, const int nex) {
     fin.seekg(16);
 
     uint8_t image[MNIST_NINPUT];
+    float * buf = ggml_get_data_f32(dataset.data);
 
-    for (int iex = 0; iex < nex; ++iex) {
+    for (int iex = 0; iex < dataset.nex; ++iex) {
         fin.read((char *) image, sizeof(image));
 
         for (int i = 0; i < MNIST_NINPUT; ++i) {
@@ -35,12 +36,12 @@ bool mnist_image_load(const std::string & fname, float * buf, const int nex) {
     return true;
 }
 
-void mnist_image_print(FILE * stream, const float * image) {
-    static_assert(MNIST_NINPUT == 28*28, "Unexpected MNIST_NINPUT");
+void mnist_image_print(FILE * stream, mnist_dataset & dataset, const int iex) {
+    const float * image = ggml_get_data_f32(dataset.data) + iex*MNIST_NINPUT;
 
-    for (int row = 0; row < 28; row++) {
-        for (int col = 0; col < 28; col++) {
-            const int rgb = roundf(255.0f * image[row*28 + col]);
+    for (int row = 0; row < MNIST_HW; row++) {
+        for (int col = 0; col < MNIST_HW; col++) {
+            const int rgb = roundf(255.0f * image[row*MNIST_HW + col]);
 #ifdef _WIN32
             fprintf(stream, "%s", rgb >= 220 ? "##" : "__");                // Represented via text.
 #else
@@ -51,7 +52,7 @@ void mnist_image_print(FILE * stream, const float * image) {
     }
 }
 
-bool mnist_label_load(const std::string & fname, float * buf, const int nex) {
+bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) {
     auto fin = std::ifstream(fname, std::ios::binary);
     if (!fin) {
         fprintf(stderr, "failed to open labels file %s\n", fname.c_str());
@@ -60,8 +61,9 @@ bool mnist_label_load(const std::string & fname, float * buf, const int nex) {
     fin.seekg(8);
 
     uint8_t label;
+    float * buf = ggml_get_data_f32(dataset.labels);
 
-    for (int iex = 0; iex < nex; ++iex) {
+    for (int iex = 0; iex < dataset.nex; ++iex) {
         fin.read((char *) &label, sizeof(label));
 
         for (int i = 0; i < MNIST_NCLASSES; ++i) {
@@ -500,7 +502,7 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int
     GGML_ASSERT(model.acc_count->ne[3] == 1);
 }
 
-mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex) {
+mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset) {
     mnist_eval_result result;
 
     struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute);
@@ -522,10 +524,10 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
         GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred));
         GGML_ASSERT(sizeof(tmp_acc_count)               == ggml_nbytes(model.acc_count));
 
-        GGML_ASSERT(nex % model.nbatch_physical == 0);
-        for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch_physical) {
-            ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
-            ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
+        GGML_ASSERT(dataset.nex % model.nbatch_physical == 0);
+        const int nbatches = dataset.nex/model.nbatch_physical;
+        for (int ibatch = 0; ibatch < nbatches; ++ibatch) {
+            dataset.get_batch(model.images, model.labels, ibatch);
 
             ggml_backend_graph_compute(model.backend, gf);
 
@@ -542,17 +544,21 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
         const int64_t t_total_us = ggml_time_us() - t_start_us;
         const double t_total_ms = 1e-3*t_total_us;
         fprintf(stderr, "%s: model evaluation on %d images took %.2lf ms, %.2lf us/image\n",
-                __func__, nex, t_total_ms, (double) t_total_us/nex);
+                __func__, (int)dataset.nex, t_total_ms, (double) t_total_us/dataset.nex);
     }
 
     result.success = true;
     return result;
 }
 
-void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split) {
+void mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split) {
     const int64_t t_start_us = ggml_time_us();
 
-    const bool accumulate = model.nbatch_physical != model.nbatch_logical;
+    const int opt_period        = model.nbatch_logical / model.nbatch_physical;
+    const int nbatches_logical  = dataset.nex / model.nbatch_logical;
+    const int nbatches_physical = dataset.nex / model.nbatch_physical;
+    const int ibatch_split      = ((int)((1.0f - val_split)*nbatches_logical))*opt_period; // train <-> val split index (physical)
+    const int ishard_split      = ibatch_split * model.nbatch_physical/dataset.shard_size;
 
     // gf == graph forward, forward pass only.
     struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
@@ -563,7 +569,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
 
     // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
     struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
-    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, accumulate);
+    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ opt_period > 1);
 
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
@@ -572,13 +578,15 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
     model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
     ggml_graph_reset(gb_opt); // Set gradients to zero, reset optimizer.
 
-    const int iex_split = ((int)((1.0f - val_split)*nex) / model.nbatch_logical) * model.nbatch_logical;
+    dataset.shuffle(-1); // Shuffle all data (train + validation).
 
     for (int epoch = 0; epoch < nepoch; ++epoch) {
         fprintf(stderr, "%s: epoch %02d start...", __func__, epoch);
         const int64_t t_start_us = ggml_time_us();
 
-        int iex0 = 0;
+        dataset.shuffle(ishard_split); // Shuffle only the training data, keeping training and validation set separate.
+
+        int ibatch_physical = 0;
 
         float                tmp_loss;
         std::vector<int32_t> tmp_pred(model.nbatch_physical);
@@ -589,13 +597,12 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
         GGML_ASSERT(sizeof(tmp_acc_count)               == ggml_nbytes(model.acc_count));
 
         mnist_eval_result result_train;
-        for (; iex0 < iex_split; iex0 += model.nbatch_physical) {
-            ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
-            ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
+        for (; ibatch_physical < ibatch_split; ++ibatch_physical) {
+            dataset.get_batch(model.images, model.labels, ibatch_physical);
 
-            // With a period of nbatch_logical/nbatch_physical iterations:
-            if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
-                // For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them:
+            // With a period of opt_period == nbatch_logical/nbatch_physical iterations:
+            if ((ibatch_physical + 1) % opt_period != 0) {
+                // For the first opt_period - 1 iterations, only calculate gradients and accumulate them:
                 ggml_backend_graph_compute(model.backend, gb_grad);
             } else {
                 // For the last iteration, calculate gradients and also apply the optimizer:
@@ -614,9 +621,8 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
         }
 
         mnist_eval_result result_val;
-        for (; iex0 < nex; iex0 += model.nbatch_physical) {
-            ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
-            ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
+        for (; ibatch_physical < nbatches_physical; ++ibatch_physical) {
+            dataset.get_batch(model.images, model.labels, ibatch_physical);
 
             ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed.
 
@@ -639,7 +645,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
             fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%", t_epoch_s, loss_mean, percent_correct);
         }
 
-        if (iex_split < nex) {
+        if (ibatch_split < nbatches_physical) {
             const std::pair<double, double> loss = mnist_loss(result_val);
             const std::pair<double, double> acc  = mnist_accuracy(result_val);
 
@@ -731,11 +737,14 @@ extern "C" {
 
 int wasm_eval(uint8_t * digitPtr) {
     std::vector<float> digit(digitPtr, digitPtr + MNIST_NINPUT);
-    std::vector<float> labels(MNIST_NCLASSES);
+
+    struct mnist_dataset dataset(1, 1);
+    memcpy(dataset.data->data, digitPtr, ggml_nbytes(dataset.data));
+    ggml_set_zero(dataset.labels); // The labels are not needed.
 
     mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU");
     mnist_model_build(model, 1, 1);
-    mnist_eval_result result = mnist_model_eval(model, digit.data(), labels.data(), 1);
+    mnist_eval_result result = mnist_model_eval(model, dataset);
 
     return result.pred[0];
 }
index 1c6fe783b25cb7a8e4f8a2b5a070fdbeea59b578..f0d248047fde9b2c0393edb289e6bac6e00ab96f 100644 (file)
@@ -1,4 +1,6 @@
+#include <algorithm>
 #include <cstdint>
+#include <random>
 #include <string>
 #include <thread>
 #include <vector>
@@ -25,6 +27,64 @@ static_assert(MNIST_NTEST  % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NB
 // NCB = number of channels base
 #define MNIST_CNN_NCB 8
 
+struct mnist_dataset {
+    struct ggml_context * ctx;
+    struct ggml_tensor  * data;
+    struct ggml_tensor  * labels;
+
+    int64_t nex;
+    int64_t shard_size;
+    size_t  nbs_data;
+    size_t  nbs_labels;
+
+    std::vector<int64_t> permutation;
+    std::mt19937 rng;
+
+    mnist_dataset(const int64_t nex, const int64_t shard_size) : nex(nex), shard_size(shard_size) {
+        const size_t nbytes_images = nex*MNIST_NINPUT  *sizeof(float) + ggml_tensor_overhead();
+        const size_t nbytes_labels = nex*MNIST_NCLASSES*sizeof(float) + ggml_tensor_overhead();
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ nbytes_images + nbytes_labels,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ false,
+        };
+        ctx = ggml_init(params);
+
+        data   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, MNIST_HW, MNIST_HW, nex);
+        labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, MNIST_NCLASSES,     nex);
+
+        nbs_data   = ggml_nbytes(data)   * shard_size/nex;
+        nbs_labels = ggml_nbytes(labels) * shard_size/nex;
+
+        permutation.resize(nex/shard_size);
+        for (size_t i = 0; i < permutation.size(); ++i) {
+            permutation[i] = i;
+        }
+    }
+
+    ~mnist_dataset() {
+        ggml_free(ctx);
+    }
+
+    void shuffle(const size_t ishard_max) {
+        if (ishard_max < permutation.size()) {
+            std::shuffle(permutation.begin(), permutation.begin() + ishard_max, rng);
+            return;
+        }
+        std::shuffle(permutation.begin(), permutation.end(), rng);
+    }
+
+    void get_batch(struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, const int64_t ibatch) {
+        const int64_t shards_per_batch = ggml_nbytes(data_batch) / nbs_data;
+        for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+            const int64_t ishard = permutation[ibatch*shards_per_batch + ishard_batch];
+
+            ggml_backend_tensor_set(data_batch,   (const char *)   data->data + ishard*nbs_data,   ishard_batch*nbs_data,   nbs_data);
+            ggml_backend_tensor_set(labels_batch, (const char *) labels->data + ishard*nbs_labels, ishard_batch*nbs_labels, nbs_labels);
+        }
+    }
+};
+
 struct mnist_model {
     std::string arch;
     ggml_backend_t backend;
@@ -116,17 +176,17 @@ struct mnist_eval_result {
     int64_t              ntotal   = 0;
 };
 
-bool mnist_image_load(const std::string & fname, float * buf, const int nex);
-void mnist_image_print(FILE * f, const float * image);
-bool mnist_label_load(const std::string & fname, float * buf, const int nex);
+bool mnist_image_load(const std::string & fname, mnist_dataset & dataset);
+void mnist_image_print(FILE * f, mnist_dataset & dataset, const int iex);
+bool mnist_label_load(const std::string & fname, mnist_dataset & dataset);
 
 mnist_eval_result mnist_graph_eval(const std::string & fname, const float * images, const float * labels, const int nex, const int nthreads);
 
 mnist_model       mnist_model_init_from_file(const std::string & fname, const std::string & backend);
 mnist_model       mnist_model_init_random(const std::string & arch, const std::string & backend);
 void              mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical);
-mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex);
-void              mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split);
+mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset);
+void              mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split);
 void              mnist_model_save(mnist_model & model, const std::string & fname);
 
 std::pair<double, double> mnist_loss(const mnist_eval_result & result);
index a5f9676faf085185637de3710488c41bc3249144..125bbb8cb154fbb4fccf08e5a24209ad2fa6cb61 100644 (file)
@@ -24,22 +24,17 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    std::vector<float> images;
-    images.resize(MNIST_NTEST*MNIST_NINPUT);
-    if (!mnist_image_load(argv[2], images.data(), MNIST_NTEST)) {
+    struct mnist_dataset dataset(/*nex =*/ MNIST_NTEST, /*shard_size =*/ MNIST_NBATCH_PHYSICAL);
+
+    if (!mnist_image_load(argv[2], dataset)) {
         return 1;
     }
-
-    std::vector<float> labels;
-    labels.resize(MNIST_NTEST*MNIST_NCLASSES);
-    if (!mnist_label_load(argv[3], labels.data(), MNIST_NTEST)) {
+    if (!mnist_label_load(argv[3], dataset)) {
         return 1;
     }
 
     const int iex = rand() % MNIST_NTEST;
-    const std::vector<float> digit(images.begin() + iex*MNIST_NINPUT, images.begin() + (iex+1)*MNIST_NINPUT);
-
-    mnist_image_print(stdout, images.data() + iex*MNIST_NINPUT);
+    mnist_image_print(stdout, dataset, iex);
 
     const std::string backend = argc >= 5 ? argv[4] : "CPU";
 
@@ -48,7 +43,7 @@ int main(int argc, char ** argv) {
     if (backend == "CPU") {
         const int ncores_logical = std::thread::hardware_concurrency();
         result_eval = mnist_graph_eval(
-            argv[1], images.data(), labels.data(), MNIST_NTEST, std::min(ncores_logical, (ncores_logical + 4)/2));
+            argv[1], ggml_get_data_f32(dataset.data), ggml_get_data_f32(dataset.labels), MNIST_NTEST, std::min(ncores_logical, (ncores_logical + 4)/2));
         if (result_eval.success) {
             fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]);
 
@@ -73,7 +68,7 @@ int main(int argc, char ** argv) {
     const int64_t t_load_us = ggml_time_us() - t_start_us;
 
     fprintf(stdout, "%s: loaded model in %.2lf ms\n", __func__, t_load_us / 1000.0);
-    result_eval = mnist_model_eval(model, images.data(), labels.data(), MNIST_NTEST);
+    result_eval = mnist_model_eval(model, dataset);
     fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]);
 
     std::pair<double, double> result_loss = mnist_loss(result_eval);
index 41a16221a8ebc773eb8da608182de177d0988247..161dcf80fb864d4ca0eb0a729c34de1c42eb8498 100644 (file)
@@ -5,7 +5,6 @@
 #include <cstring>
 #include <ctime>
 #include <string>
-#include <thread>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -17,15 +16,15 @@ int main(int argc, char ** argv) {
         exit(0);
     }
 
-    std::vector<float> images;
-    images.resize(MNIST_NTRAIN*MNIST_NINPUT);
-    if (!mnist_image_load(argv[3], images.data(), MNIST_NTRAIN)) {
+    // The MNIST model is so small that the overhead from data shuffling is non-negligible, especially with CUDA.
+    // With a shard size of 10 this overhead is greatly reduced at the cost of less shuffling (does not seem to have a significant impact).
+    // A batch of 500 images then consists of 50 random shards of size 10 instead of 500 random shards of size 1.
+    struct mnist_dataset dataset(/*nex =*/ MNIST_NTRAIN, /*shard_size =*/ 10);
+
+    if (!mnist_image_load(argv[3], dataset)) {
         return 1;
     }
-
-    std::vector<float> labels;
-    labels.resize(MNIST_NTRAIN*MNIST_NCLASSES);
-    if (!mnist_label_load(argv[4], labels.data(), MNIST_NTRAIN)) {
+    if (!mnist_label_load(argv[4], dataset)) {
         return 1;
     }
 
@@ -33,7 +32,7 @@ int main(int argc, char ** argv) {
 
     mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);
 
-    mnist_model_train(model, images.data(), labels.data(), MNIST_NTRAIN, /*nepoch =*/ 30, /*val_split =*/ 0.05f);
+    mnist_model_train(model, dataset, /*nepoch =*/ 30, /*val_split =*/ 0.05f);
 
     mnist_model_save(model, argv[2]);
 }