From: Radoslav Gerganov
Date: Sun, 27 Aug 2023 17:46:51 +0000 (+0300)
Subject: Add MNIST inference example with CNN
X-Git-Tag: upstream/0.0.1642~1257^2
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9a287f7a1b6c24c9e090dc782a9eca5757105f71;p=pkg%2Fggml%2Fsources%2Fggml

Add MNIST inference example with CNN

Add one more implementation for MNIST which uses Conv2D layers, ref:
https://keras.io/examples/vision/mnist_convnet/. It achieves ~99%
accuracy on the MNIST test set and also performs better on user
inputs. This implementation expects a model in GGUF format. You can
get one with the 'mnist-cnn.py' script. Example usage:

$ ./mnist-cnn.py train mnist-cnn-model
...
Keras model saved to 'mnist-cnn-model'
$ ./mnist-cnn.py convert mnist-cnn-model
...
Model converted and saved to 'mnist-cnn-model.gguf'
$ ./mnist-cnn mnist-cnn-model.gguf models/mnist/t10k-images.idx3-ubyte
---

diff --git a/examples/mnist/CMakeLists.txt b/examples/mnist/CMakeLists.txt
index 3ce09249..4d9b93ed 100644
--- a/examples/mnist/CMakeLists.txt
+++ b/examples/mnist/CMakeLists.txt
@@ -5,6 +5,13 @@ set(TEST_TARGET mnist)
 add_executable(${TEST_TARGET} main.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
 
+#
+# mnist-cnn
+
+set(TEST_TARGET mnist-cnn)
+add_executable(${TEST_TARGET} main-cnn.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
+
 #
 # mnist-cpu
 
diff --git a/examples/mnist/main-cnn.cpp b/examples/mnist/main-cnn.cpp
new file mode 100644
index 00000000..7949e9a6
--- /dev/null
+++ b/examples/mnist/main-cnn.cpp
@@ -0,0 +1,169 @@
+#include "ggml/ggml.h"
+
+#include "common.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <string>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+struct mnist_model {
+    struct ggml_tensor * conv2d_1_kernel;
+    struct ggml_tensor * conv2d_1_bias;
+    struct ggml_tensor * conv2d_2_kernel;
+    struct ggml_tensor * conv2d_2_bias;
+    struct ggml_tensor * dense_weight;
+    struct ggml_tensor * dense_bias;
+    struct ggml_context * ctx;
+};
+
+bool mnist_model_load(const std::string & fname, mnist_model & model) {
+    struct gguf_init_params params = {
+        /*.no_alloc =*/ false,
+        /*.ctx      =*/ &model.ctx,
+    };
+    gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
+    if (!ctx) {
+        fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
+        return false;
+    }
+    model.conv2d_1_kernel = ggml_get_tensor(model.ctx, "kernel1");
+    model.conv2d_1_bias   = ggml_get_tensor(model.ctx, "bias1");
+    model.conv2d_2_kernel = ggml_get_tensor(model.ctx, "kernel2");
+    model.conv2d_2_bias   = ggml_get_tensor(model.ctx, "bias2");
+    model.dense_weight    = ggml_get_tensor(model.ctx, "dense_w");
+    model.dense_bias      = ggml_get_tensor(model.ctx, "dense_b");
+    return true;
+}
+
+int mnist_eval(
+        const mnist_model & model,
+        const int n_threads,
+        std::vector<float> digit,
+        const char * fname_cgraph
+    )
+{
+    static size_t buf_size = 100000 * sizeof(float) * 4;
+    static void * buf = malloc(buf_size);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf,
+        /*.no_alloc   =*/ false,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = {};
+
+    struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 28, 28, 1, 1);
+    memcpy(input->data, digit.data(), ggml_nbytes(input));
+    ggml_set_name(input, "input");
+    ggml_tensor * cur = ggml_conv_2d(ctx0, model.conv2d_1_kernel, input, 1, 1, 0, 0, 1, 1);
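+    // NOTE: the converter (mnist-cnn.py) stores each Conv2D bias pre-broadcast to the
+    // full output shape of its layer, so it can be applied with a plain element-wise ggml_add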
+    cur = ggml_add(ctx0, cur, model.conv2d_1_bias);
+    cur = ggml_relu(ctx0, cur);
+    // Output shape after Conv2D: (26 26 32 1)
+    cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    // Output shape after MaxPooling2D: (13 13 32 1)
+    cur = ggml_conv_2d(ctx0, model.conv2d_2_kernel, cur, 1, 1, 0, 0, 1, 1);
+    cur = ggml_add(ctx0, cur, model.conv2d_2_bias);
+    cur = ggml_relu(ctx0, cur);
+    // Output shape after Conv2D: (11 11 64 1)
+    cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    // Output shape after MaxPooling2D: (5 5 64 1)
+    cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+    // Output shape after permute: (64 5 5 1)
+    cur = ggml_reshape_2d(ctx0, cur, 1600, 1);
+    // Final Dense layer
+    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.dense_weight, cur), model.dense_bias);
+    ggml_tensor * probs = ggml_soft_max(ctx0, cur);
+    ggml_set_name(probs, "probs");
+
+    ggml_build_forward_expand(&gf, probs);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
+    //ggml_graph_print(&gf);
+    ggml_graph_dump_dot(&gf, NULL, "mnist-cnn.dot");
+
+    if (fname_cgraph) {
+        // export the compute graph for later use
+        // see the "mnist-cpu" example
+        ggml_graph_export(&gf, fname_cgraph);
+
+        fprintf(stderr, "%s: exported compute graph to '%s'\n", __func__, fname_cgraph);
+    }
+
+    const float * probs_data = ggml_get_data_f32(probs);
+    const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
+    ggml_free(ctx0);
+    return prediction;
+}
+
+int main(int argc, char ** argv) {
+    srand(time(NULL));
+    ggml_time_init();
+
+    if (argc != 3) {
+        fprintf(stderr, "Usage: %s models/mnist/mnist-cnn.gguf models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
+        exit(0);
+    }
+
+    uint8_t buf[784];
+    mnist_model model;
+    std::vector<float> digit;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!mnist_model_load(argv[1], model)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, argv[1]);
+            return 1;
+        }
+
+        const int64_t t_load_us = ggml_time_us() - t_start_us;
+
+        fprintf(stdout, "%s: loaded model in %8.2f ms\n", __func__, t_load_us / 1000.0f);
+    }
+
+    // read a random digit from the test set
+    {
+        std::ifstream fin(argv[2], std::ios::binary);
+        if (!fin) {
+            fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
+            return 1;
+        }
+
+        // seek to a random digit: 16-byte header + 28*28 * (random number between 0 and 9999)
+        fin.seekg(16 + 784 * (rand() % 10000));
+        fin.read((char *) &buf, sizeof(buf));
+    }
+
+    // render the digit in ASCII
+    {
+        digit.resize(sizeof(buf));
+
+        for (int row = 0; row < 28; row++) {
+            for (int col = 0; col < 28; col++) {
+                fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
+                digit[row*28 + col] = ((float)buf[row*28 + col] / 255.0f);
+            }
+
+            fprintf(stderr, "\n");
+        }
+
+        fprintf(stderr, "\n");
+    }
+
+    const int prediction = mnist_eval(model, 1, digit, nullptr);
+    fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);
+    ggml_free(model.ctx);
+    return 0;
+}
diff --git a/examples/mnist/mnist-cnn.py b/examples/mnist/mnist-cnn.py
new file mode 100755
index 00000000..35dda60a
--- /dev/null
+++ b/examples/mnist/mnist-cnn.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+import sys
+import gguf
+import numpy as np
+from tensorflow import keras
+from tensorflow.keras import layers
+
+def train(model_name):
+    # Model / data parameters
+    num_classes = 10
+    input_shape = (28, 28, 1)
+
+    # Load the data and split it between train and test sets
+    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+    # Scale images to the [0, 1] range
+    x_train = x_train.astype("float32") / 255
+    x_test = x_test.astype("float32") / 255
+    # Make sure images have shape (28, 28, 1)
+    x_train = np.expand_dims(x_train, -1)
+    x_test = np.expand_dims(x_test, -1)
+    print("x_train shape:", x_train.shape)
+    print(x_train.shape[0], "train samples")
+    print(x_test.shape[0], "test samples")
+
+    # convert class vectors to binary class matrices
+    y_train = keras.utils.to_categorical(y_train, num_classes)
+    y_test = keras.utils.to_categorical(y_test, num_classes)
+
+    model = keras.Sequential(
+        [
+            keras.Input(shape=input_shape),
+            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Flatten(),
+            layers.Dropout(0.5),
+            layers.Dense(num_classes, activation="softmax"),
+        ]
+    )
+
+    model.summary()
+    batch_size = 128
+    epochs = 15
+    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
+
+    score = model.evaluate(x_test, y_test, verbose=0)
+    print("Test loss:", score[0])
+    print("Test accuracy:", score[1])
+    model.save(model_name)
+    print("Keras model saved to '" + model_name + "'")
+
+def convert(model_name):
+    model = keras.models.load_model(model_name)
+    gguf_model_name = model_name + ".gguf"
+    gguf_writer = gguf.GGUFWriter(gguf_model_name, "mnist-cnn")
+
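+    # Keras stores Conv2D kernels as (rows, cols, in_channels, out_channels); re-order them
+    # to (out_channels, in_channels, rows, cols) and convert to float16 for ggml_conv_2d.
+    # The per-channel biases are tiled to the full Conv2D output shape so that main-cnn.cpp
+    # can add them element-wise.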
+    kernel1 = model.layers[0].weights[0].numpy()
+    kernel1 = np.moveaxis(kernel1, [2,3], [0,1])
+    kernel1 = kernel1.astype(np.float16)
+    gguf_writer.add_tensor("kernel1", kernel1, raw_shape=(32, 1, 3, 3))
+
+    bias1 = model.layers[0].weights[1].numpy()
+    bias1 = np.repeat(bias1, 26*26)
+    gguf_writer.add_tensor("bias1", bias1, raw_shape=(1, 32, 26, 26))
+
+    kernel2 = model.layers[2].weights[0].numpy()
+    kernel2 = np.moveaxis(kernel2, [0,1,2,3], [2,3,1,0])
+    kernel2 = kernel2.astype(np.float16)
+    gguf_writer.add_tensor("kernel2", kernel2, raw_shape=(64, 32, 3, 3))
+
+    bias2 = model.layers[2].weights[1].numpy()
+    bias2 = np.repeat(bias2, 11*11)
+    gguf_writer.add_tensor("bias2", bias2, raw_shape=(1, 64, 11, 11))
+
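+    # ggml_mul_mat contracts over the first dimension of both operands, so the Keras
+    # (1600, 10) Dense kernel is stored transposed as (10, 1600)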
'{}'".format(gguf_model_name)) + +if __name__ == '__main__': + if len(sys.argv) < 3: + print("Usage: %s ".format(sys.argv[0])) + sys.exit(1) + if sys.argv[1] == 'train': + train(sys.argv[2]) + elif sys.argv[1] == 'convert': + convert(sys.argv[2]) + else: + print("Usage: %s ".format(sys.argv[0])) + sys.exit(1)