]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Add MNIST inference example with CNN
authorRadoslav Gerganov <redacted>
Sun, 27 Aug 2023 17:46:51 +0000 (20:46 +0300)
committerRadoslav Gerganov <redacted>
Mon, 28 Aug 2023 13:24:54 +0000 (16:24 +0300)
Add one more implementation for MNIST which uses Conv2D layers, ref:
https://keras.io/examples/vision/mnist_convnet/. It achieves ~99%
accuracy on the MNIST test set and also performs better for user inputs.
This implementation expects a model in GGUF format. You can get one with
the 'mnist-cnn.py' script. Example usage:

$ ./mnist-cnn.py train mnist-cnn-model
...
Keras model saved to 'mnist-cnn-model'

$ ./mnist-cnn.py convert mnist-cnn-model
...
Model converted and saved to 'mnist-cnn-model.gguf'

$ ./mnist-cnn mnist-cnn-model.gguf models/mnist/t10k-images.idx3-ubyte

examples/mnist/CMakeLists.txt
examples/mnist/main-cnn.cpp [new file with mode: 0644]
examples/mnist/mnist-cnn.py [new file with mode: 0755]

index 3ce0924902e3e48e56b4b5aab3890a23e2f1dba2..4d9b93edc36dd84997340df3c9f49cf83cbbe9bc 100644 (file)
@@ -5,6 +5,13 @@ set(TEST_TARGET mnist)
 add_executable(${TEST_TARGET} main.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
 
+#
+# mnist-cnn
+
+set(TEST_TARGET mnist-cnn)
+add_executable(${TEST_TARGET} main-cnn.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common)
+
 #
 # mnist-cpu
 
diff --git a/examples/mnist/main-cnn.cpp b/examples/mnist/main-cnn.cpp
new file mode 100644 (file)
index 0000000..7949e9a
--- /dev/null
@@ -0,0 +1,169 @@
+#include "ggml/ggml.h"
+
+#include "common.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <fstream>
+#include <string>
+#include <vector>
+#include <algorithm>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+struct mnist_model {
+    struct ggml_tensor * conv2d_1_kernel;
+    struct ggml_tensor * conv2d_1_bias;
+    struct ggml_tensor * conv2d_2_kernel;
+    struct ggml_tensor * conv2d_2_bias;
+    struct ggml_tensor * dense_weight;
+    struct ggml_tensor * dense_bias;
+    struct ggml_context * ctx;
+};
+
+bool mnist_model_load(const std::string & fname, mnist_model & model) {
+    struct gguf_init_params params = {
+        /*.no_alloc   =*/ false,
+        /*.ctx        =*/ &model.ctx,
+    };
+    gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
+    if (!ctx) {
+        fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
+        return false;
+    }
+    model.conv2d_1_kernel = ggml_get_tensor(model.ctx, "kernel1");
+    model.conv2d_1_bias = ggml_get_tensor(model.ctx, "bias1");
+    model.conv2d_2_kernel = ggml_get_tensor(model.ctx, "kernel2");
+    model.conv2d_2_bias = ggml_get_tensor(model.ctx, "bias2");
+    model.dense_weight = ggml_get_tensor(model.ctx, "dense_w");
+    model.dense_bias = ggml_get_tensor(model.ctx, "dense_b");
+    return true;
+}
+
+int mnist_eval(
+        const mnist_model & model,
+        const int n_threads,
+        std::vector<float> digit,
+        const char * fname_cgraph
+        )
+{
+    static size_t buf_size = 100000 * sizeof(float) * 4;
+    static void * buf = malloc(buf_size);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf,
+        /*.no_alloc   =*/ false,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = {};
+
+    struct ggml_tensor * input = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 28, 28, 1, 1);
+    memcpy(input->data, digit.data(), ggml_nbytes(input));
+    ggml_set_name(input, "input");
+    ggml_tensor * cur = ggml_conv_2d(ctx0, model.conv2d_1_kernel, input, 1, 1, 0, 0, 1, 1);
+    cur = ggml_add(ctx0, cur, model.conv2d_1_bias);
+    cur = ggml_relu(ctx0, cur);
+    // Output shape after Conv2D: (26 26 32 1)
+    cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    // Output shape after MaxPooling2D: (13 13 32 1)
+    cur = ggml_conv_2d(ctx0, model.conv2d_2_kernel, cur, 1, 1, 0, 0, 1, 1);
+    cur = ggml_add(ctx0, cur, model.conv2d_2_bias);
+    cur = ggml_relu(ctx0, cur);
+    // Output shape after Conv2D: (11 11 64 1)
+    cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
+    // Output shape after MaxPooling2D: (5 5 64 1)
+    cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+    // Output shape after permute: (64 5 5 1)
+    cur = ggml_reshape_2d(ctx0, cur, 1600, 1);
+    // Final Dense layer
+    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.dense_weight, cur), model.dense_bias);
+    ggml_tensor * probs = ggml_soft_max(ctx0, cur);
+    ggml_set_name(probs, "probs");
+
+    ggml_build_forward_expand(&gf, probs);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
+    //ggml_graph_print(&gf);
+    ggml_graph_dump_dot(&gf, NULL, "mnist-cnn.dot");
+
+    if (fname_cgraph) {
+        // export the compute graph for later use
+        // see the "mnist-cpu" example
+        ggml_graph_export(&gf, fname_cgraph);
+
+        fprintf(stderr, "%s: exported compute graph to '%s'\n", __func__, fname_cgraph);
+    }
+
+    const float * probs_data = ggml_get_data_f32(probs);
+    const int prediction = std::max_element(probs_data, probs_data + 10) - probs_data;
+    ggml_free(ctx0);
+    return prediction;
+}
+
+int main(int argc, char ** argv) {
+    srand(time(NULL));
+    ggml_time_init();
+
+    if (argc != 3) {
+        fprintf(stderr, "Usage: %s models/mnist/mnist-cnn.gguf models/mnist/t10k-images.idx3-ubyte\n", argv[0]);
+        exit(0);
+    }
+
+    uint8_t buf[784];
+    mnist_model model;
+    std::vector<float> digit;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!mnist_model_load(argv[1], model)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, argv[1]);
+            return 1;
+        }
+
+        const int64_t t_load_us = ggml_time_us() - t_start_us;
+
+        fprintf(stdout, "%s: loaded model in %8.2f ms\n", __func__, t_load_us / 1000.0f);
+    }
+
+    // read a random digit from the test set
+    {
+        std::ifstream fin(argv[2], std::ios::binary);
+        if (!fin) {
+            fprintf(stderr, "%s: failed to open '%s'\n", __func__, argv[2]);
+            return 1;
+        }
+
+        // seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
+        fin.seekg(16 + 784 * (rand() % 10000));
+        fin.read((char *) &buf, sizeof(buf));
+    }
+
+    // render the digit in ASCII
+    {
+        digit.resize(sizeof(buf));
+
+        for (int row = 0; row < 28; row++) {
+            for (int col = 0; col < 28; col++) {
+                fprintf(stderr, "%c ", (float)buf[row*28 + col] > 230 ? '*' : '_');
+                digit[row*28 + col] = ((float)buf[row*28 + col] / 255.0f);
+            }
+
+            fprintf(stderr, "\n");
+        }
+
+        fprintf(stderr, "\n");
+    }
+
+    const int prediction = mnist_eval(model, 1, digit, nullptr);
+    fprintf(stdout, "%s: predicted digit is %d\n", __func__, prediction);
+    ggml_free(model.ctx);
+    return 0;
+}
diff --git a/examples/mnist/mnist-cnn.py b/examples/mnist/mnist-cnn.py
new file mode 100755 (executable)
index 0000000..35dda60
--- /dev/null
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+import sys
+import gguf
+import numpy as np
+from tensorflow import keras
+from tensorflow.keras import layers
+
+def train(model_name):
+    # Model / data parameters
+    num_classes = 10
+    input_shape = (28, 28, 1)
+
+    # Load the data and split it between train and test sets
+    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+    # Scale images to the [0, 1] range
+    x_train = x_train.astype("float32") / 255
+    x_test = x_test.astype("float32") / 255
+    # Make sure images have shape (28, 28, 1)
+    x_train = np.expand_dims(x_train, -1)
+    x_test = np.expand_dims(x_test, -1)
+    print("x_train shape:", x_train.shape)
+    print(x_train.shape[0], "train samples")
+    print(x_test.shape[0], "test samples")
+
+    # convert class vectors to binary class matrices
+    y_train = keras.utils.to_categorical(y_train, num_classes)
+    y_test = keras.utils.to_categorical(y_test, num_classes)
+
+    model = keras.Sequential(
+        [
+            keras.Input(shape=input_shape),
+            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Flatten(),
+            layers.Dropout(0.5),
+            layers.Dense(num_classes, activation="softmax"),
+        ]
+    )
+
+    model.summary()
+    batch_size = 128
+    epochs = 15
+    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
+
+    score = model.evaluate(x_test, y_test, verbose=0)
+    print("Test loss:", score[0])
+    print("Test accuracy:", score[1])
+    model.save(model_name)
+    print("Keras model saved to '" + model_name + "'")
+
+def convert(model_name):
+    model = keras.models.load_model(model_name)
+    gguf_model_name = model_name + ".gguf"
+    gguf_writer = gguf.GGUFWriter(gguf_model_name, "mnist-cnn")
+
+    kernel1 = model.layers[0].weights[0].numpy()
+    kernel1 = np.moveaxis(kernel1, [2,3], [0,1])
+    kernel1 = kernel1.astype(np.float16)
+    gguf_writer.add_tensor("kernel1", kernel1, raw_shape=(32, 1, 3, 3))
+
+    bias1 = model.layers[0].weights[1].numpy()
+    bias1 = np.repeat(bias1, 26*26)
+    gguf_writer.add_tensor("bias1", bias1, raw_shape=(1, 32, 26, 26))
+
+    kernel2 = model.layers[2].weights[0].numpy()
+    kernel2 = np.moveaxis(kernel2, [0,1,2,3], [2,3,1,0])
+    kernel2 = kernel2.astype(np.float16)
+    gguf_writer.add_tensor("kernel2", kernel2, raw_shape=(64, 32, 3, 3))
+
+    bias2 = model.layers[2].weights[1].numpy()
+    bias2 = np.repeat(bias2, 11*11)
+    gguf_writer.add_tensor("bias2", bias2, raw_shape=(1, 64, 11, 11))
+
+    dense_w = model.layers[-1].weights[0].numpy()
+    dense_w = dense_w.transpose()
+    gguf_writer.add_tensor("dense_w", dense_w, raw_shape=(10, 1600))
+
+    dense_b = model.layers[-1].weights[1].numpy()
+    gguf_writer.add_tensor("dense_b", dense_b)
+
+    gguf_writer.write_header_to_file()
+    gguf_writer.write_kv_data_to_file()
+    gguf_writer.write_tensors_to_file()
+    gguf_writer.close()
+    print("Model converted and saved to '{}'".format(gguf_model_name))
+
+if __name__ == '__main__':
+    if len(sys.argv) < 3:
+        print("Usage: %s <train|convert> <model_name>".format(sys.argv[0]))
+        sys.exit(1)
+    if sys.argv[1] == 'train':
+        train(sys.argv[2])
+    elif sys.argv[1] == 'convert':
+        convert(sys.argv[2])
+    else:
+        print("Usage: %s <train|convert> <model_name>".format(sys.argv[0]))
+        sys.exit(1)