From: Ray Cromwell
Date: Thu, 13 Apr 2023 20:49:45 +0000 (-0700)
Subject: examples : MNIST example for ggml (#84)
X-Git-Tag: upstream/0.0.1642~1549
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=73999158eee3d493585fad2262a6675e7de90f8b;p=pkg%2Fggml%2Fsources%2Fggml

examples : MNIST example for ggml (#84)
---

diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 5f7f3a48..d650d92e 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -4,3 +4,4 @@ target_include_directories(ggml_utils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
 add_subdirectory(gpt-2)
 add_subdirectory(gpt-j)
 add_subdirectory(whisper)
+add_subdirectory(mnist)
diff --git a/examples/mnist/CMakeLists.txt b/examples/mnist/CMakeLists.txt
new file mode 100644
index 00000000..8df55a36
--- /dev/null
+++ b/examples/mnist/CMakeLists.txt
@@ -0,0 +1,7 @@
+#
+# mnist
+#
+
+set(TEST_TARGET mnist)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml ggml_utils)
+
diff --git a/examples/mnist/README.md b/examples/mnist/README.md
new file mode 100644
index 00000000..d0ab2dba
--- /dev/null
+++ b/examples/mnist/README.md
@@ -0,0 +1,47 @@
+# MNIST Example for GGML
+
+This is a simple example of how to use ggml for inference.
+
+## Training the Model
+
+A Google Colab notebook for training a simple two-layer network to recognize digits is linked below. You can
+use it to save a PyTorch model, which can then be converted to ggml format.
+
+[Colab](https://colab.research.google.com/drive/12n_8VNJnolBnX5dVS0HNWubnOjyEaFSb?usp=sharing)
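+
+For reference, here is a minimal sketch of the kind of two-layer network the notebook trains. The class name
+and layer sizes are assumptions based on this example's default hyperparameters (784 -> 500 -> 10); the actual
+notebook may differ:
+
+```python
+import torch
+import torch.nn as nn
+
+class MNISTNet(nn.Module):
+    # two fully connected layers, matching the default hparams in main.cpp
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(784, 500)
+        self.fc2 = nn.Linear(500, 10)
+
+    def forward(self, x):
+        # softmax is applied at inference time, by the ggml graph
+        return self.fc2(torch.relu(self.fc1(x)))
+
+model = MNISTNet()
+# after training, save only the state_dict so it can be converted to ggml
+torch.save(model.state_dict(), "mnist_model.state_dict")
+```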
+
+## GGML Format Conversion
+
+There is no single ggml "format"; each example uses whatever layout it can load efficiently. In this case we
+save a magic constant followed by the model weights and biases; the hyperparameters are recovered from the
+tensor dimensions at load time. Run `convert-h5-to-ggml.py` to convert your PyTorch model. The output format is:
+
+- magic constant (int32)
+- repeated list of tensors:
+  - number of dimensions of the tensor (int32)
+  - tensor dimensions (int32, repeated, in reverse order)
+  - values of the tensor (float32)
+
+Run `convert-h5-to-ggml.py mnist_model.state_dict`, where `mnist_model.state_dict` is the PyTorch model saved
+from the Google Colab. For a quick start, a pre-trained copy is included in the `mnist/models` directory.
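+
+As a sanity check, the converted file can be read back with a few lines of Python. This is a minimal sketch
+that mirrors the converter's write loop (the path is the converter's default output path):
+
+```python
+import struct
+import numpy as np
+
+with open("models/mnist/ggml-model-f32.bin", "rb") as f:
+    magic, = struct.unpack("i", f.read(4))
+    assert magic == 0x67676d6c  # "ggml"
+    while True:
+        header = f.read(4)
+        if not header:
+            break
+        n_dims, = struct.unpack("i", header)
+        # dimensions are stored in reverse order (ne[0] is the fastest-changing)
+        dims = struct.unpack("i" * n_dims, f.read(4 * n_dims))[::-1]
+        data = np.fromfile(f, dtype=np.float32, count=int(np.prod(dims)))
+        print(dims, data.shape)
+```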
+
+## MNIST Network
+
+The MNIST recognizer network is extremely simple: a fully connected layer with a ReLU activation, followed by
+a fully connected layer with a softmax. This version of the MNIST network doesn't use convolutions.
+
+## Running the example
+
+Here is how to run the example programs:
+
+```bash
+# Build ggml + examples
+git clone https://github.com/ggerganov/ggml
+cd ggml
+mkdir build && cd build
+cmake ..
+make -j4 mnist
+
+# Run the MNIST model
+./bin/mnist ../examples/mnist/models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
+```
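+
+The second argument is the raw MNIST test set: a 16-byte header followed by 10000 images of 28x28 grayscale
+pixels, from which the program picks one at random. Here is a minimal Python sketch of the same lookup,
+assuming it is run from the repository root:
+
+```python
+import random
+
+with open("examples/mnist/models/mnist/t10k-images.idx3-ubyte", "rb") as f:
+    f.seek(16 + 784 * random.randrange(10000))  # skip the header, pick a random digit
+    pixels = f.read(784)
+
+# render the digit in ASCII, the same way main.cpp does
+for row in range(28):
+    print("".join("*" if pixels[row*28 + col] > 230 else "_" for col in range(28)))
+```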
+
+For more information, check out the corresponding programs in the [examples](examples) folder.
diff --git a/examples/mnist/convert-h5-to-ggml.py b/examples/mnist/convert-h5-to-ggml.py
new file mode 100644
index 00000000..a4f75365
--- /dev/null
+++ b/examples/mnist/convert-h5-to-ggml.py
@@ -0,0 +1,63 @@
+# Convert the MNIST model (a saved PyTorch state_dict) to ggml format
+#
+# Load the saved state_dict using PyTorch,
+# iterate over all variables and write them to a binary file.
+#
+# For each variable, write the following:
+#   - Number of dimensions (int32)
+#   - Dimensions (int32[n_dims], in reverse order)
+#   - Data (float32)
+#
+# At the start of the ggml file we write the magic constant.
+
+import sys
+import struct
+
+import numpy as np
+import torch
+
+if len(sys.argv) != 2:
+    print("Usage: convert-h5-to-ggml.py model\n")
+    sys.exit(1)
+
+state_dict_file = sys.argv[1]
+fname_out = "models/mnist/ggml-model-f32.bin"
+
+state_dict = torch.load(state_dict_file, map_location=torch.device('cpu'))
+
+list_vars = state_dict
+print(list_vars)
+
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: "ggml" in hex
+
+for name in list_vars.keys():
+    data = list_vars[name].squeeze().numpy()
+    print("Processing variable: " + name + " with shape: ", data.shape)
+    n_dims = len(data.shape)
+
+    fout.write(struct.pack("i", n_dims))
+
+    data = data.astype(np.float32)
+    # ggml stores dimensions in reverse order (ne[0] is the fastest-changing)
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+
+    # raw float32 tensor data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/examples/mnist/main.cpp b/examples/mnist/main.cpp
new file mode 100644
index 00000000..3414f24a
--- /dev/null
+++ b/examples/mnist/main.cpp
@@ -0,0 +1,253 @@
+#include "ggml/ggml.h"
+
+#include "utils.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <string>
+#include <vector>
+
+// default hparams
+struct mnist_hparams {
+    int32_t n_input   = 784;
+    int32_t n_hidden  = 500;
+    int32_t n_classes = 10;
+};
+
+struct mnist_model {
+    mnist_hparams hparams;
+
+    struct ggml_tensor * fc1_weight;
+    struct ggml_tensor * fc1_bias;
+
+    struct ggml_tensor * fc2_weight;
+    struct ggml_tensor * fc2_bias;
+
+    struct ggml_context * ctx;
+};
+
+// load the model's weights from a file
+bool mnist_model_load(const std::string & fname, mnist_model & model) {
+    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    auto & ctx = model.ctx;
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_input   = hparams.n_input;
+        const int n_hidden  = hparams.n_hidden;
+        const int n_classes = hparams.n_classes;
+
+        ctx_size += n_input  * n_hidden  * ggml_type_sizef(GGML_TYPE_F32); // fc1 weight
+        ctx_size += n_hidden             * ggml_type_sizef(GGML_TYPE_F32); // fc1 bias
+
+        ctx_size += n_hidden * n_classes * ggml_type_sizef(GGML_TYPE_F32); // fc2 weight
+        ctx_size += n_classes            * ggml_type_sizef(GGML_TYPE_F32); // fc2 bias
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = ctx_size + 1024*1024,
+            .mem_buffer = NULL,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // Read the FC1 layer
+    {
+        // Read dimensions
+        int32_t n_dims;
+        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+
+        int32_t ne_weight[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne_weight[i]), sizeof(ne_weight[i]));
+        }
+
+        // FC1 dimensions taken from file, e.g. 784x500
+        model.hparams.n_input  = ne_weight[0];
+        model.hparams.n_hidden = ne_weight[1];
+
+        model.fc1_weight = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, model.hparams.n_input, model.hparams.n_hidden);
+        fin.read(reinterpret_cast<char *>(model.fc1_weight->data), ggml_nbytes(model.fc1_weight));
+
+        // the bias record is stored as (n_dims, dim); reading n_dims (= 2) int32 values here
+        // consumes exactly that pair
+        int32_t ne_bias[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne_bias[i]), sizeof(ne_bias[i]));
+        }
+
+        model.fc1_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_hidden);
+        fin.read(reinterpret_cast<char *>(model.fc1_bias->data), ggml_nbytes(model.fc1_bias));
+    }
+
+    // Read the FC2 layer
+    {
+        // Read dimensions
+        int32_t n_dims;
+        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+
+        int32_t ne_weight[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne_weight[i]), sizeof(ne_weight[i]));
+        }
+
+        // FC2 dimensions taken from file, e.g. 500x10
+        model.hparams.n_classes = ne_weight[1];
+
+        model.fc2_weight = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, model.hparams.n_hidden, model.hparams.n_classes);
+        fin.read(reinterpret_cast<char *>(model.fc2_weight->data), ggml_nbytes(model.fc2_weight));
+
+        int32_t ne_bias[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne_bias[i]), sizeof(ne_bias[i]));
+        }
+
+        model.fc2_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_classes);
+        fin.read(reinterpret_cast<char *>(model.fc2_bias->data), ggml_nbytes(model.fc2_bias));
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the model
+//
+// - model:     the model
+// - n_threads: number of threads to use
+// - digit:     784 pixel values
+//
+// returns the 0 - 9 prediction
+int mnist_eval(
+        const mnist_model & model,
+        const int n_threads,
+        std::vector<float> digit
+        ) {
+
+    const auto & hparams = model.hparams;
+
+    static size_t buf_size = hparams.n_input * sizeof(float) * 4;
+    static void * buf = malloc(buf_size);
+
+    struct ggml_init_params params = {
+        .mem_size   = buf_size,
+        .mem_buffer = buf,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = { .n_threads = n_threads };
+
+    struct ggml_tensor * input = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_input);
+    memcpy(input->data, digit.data(), ggml_nbytes(input));
+
+    // fc1 MLP = Ax + b
+    ggml_tensor * fc1 = ggml_add(ctx0, ggml_mul_mat(ctx0, model.fc1_weight, input),                 model.fc1_bias);
+    ggml_tensor * fc2 = ggml_add(ctx0, ggml_mul_mat(ctx0, model.fc2_weight, ggml_relu(ctx0, fc1)), model.fc2_bias);
+
+    // soft max
+    ggml_tensor * final = ggml_soft_max(ctx0, fc2);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, final);
+    ggml_graph_compute       (ctx0, &gf);
+
+    ggml_graph_print   (&gf);
+    ggml_graph_dump_dot(&gf, NULL, "mnist.dot");
+
+    const float * final_data = ggml_get_data_f32(final);
+
+    // the prediction is the index of the largest probability
+    int prediction = std::max_element(final_data, final_data + 10) - final_data;
+
+    ggml_free(ctx0);
+
+    return prediction;
+}
+
+int main(int argc, char ** argv) {
+    if (argc != 3) {
+        fprintf(stderr, "Usage: %s models/mnist/ggml-model-f32.bin models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
+        exit(1);
+    }
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_load_us = 0;
+
+    mnist_model model;
+    std::vector<float> digit;
+
+    // load the model, load a random test digit, evaluate the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!mnist_model_load(argv[1], model)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, argv[1]);
+            return 1;
+        }
+
+        auto fin = std::ifstream(argv[2], std::ios::binary);
+        if (!fin) {
+            fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
+            return 1;
+        }
+
+        unsigned char buf[784];
+        srand(time(NULL));
+
+        // seek to a random digit: 16-byte header + 784 bytes per image, random index in [0, 10000)
+        fin.seekg(16 + 784 * (rand() % 10000));
+        fin.read((char *) &buf, sizeof(buf));
+        digit.resize(sizeof(buf));
+
+        // render the digit in ASCII
+        for (int row = 0; row < 28; row++) {
+            for (int col = 0; col < 28; col++) {
+                fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
+                digit[row*28 + col] = (float)buf[row*28 + col];
+            }
+            fprintf(stderr, "\n");
+        }
+        fprintf(stderr, "\n");
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    fprintf(stdout, "Predicted digit is %d\n", mnist_eval(model, 1, digit));
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/examples/mnist/models/mnist/mnist_model.state_dict b/examples/mnist/models/mnist/mnist_model.state_dict
new file mode 100644
index 00000000..dfb609b8
Binary files /dev/null and b/examples/mnist/models/mnist/mnist_model.state_dict differ
diff --git a/examples/mnist/models/mnist/t10k-images.idx3-ubyte b/examples/mnist/models/mnist/t10k-images.idx3-ubyte
new file mode 100644
index 00000000..1170b2ca
Binary files /dev/null and b/examples/mnist/models/mnist/t10k-images.idx3-ubyte differ