]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
mnist : add web page for the MNIST example (#190)
authorRadoslav Gerganov <redacted>
Wed, 24 May 2023 08:40:47 +0000 (11:40 +0300)
committerGitHub <redacted>
Wed, 24 May 2023 08:40:47 +0000 (11:40 +0300)
The web page is using WASM for model inference.
Users can draw digits on an HTML canvas and load random digits from the
MNIST dataset.

examples/mnist/main.cpp
examples/mnist/web/index.html [new file with mode: 0644]

index fb052491bf688e6b433f8664792c0205c51b2316..60a8fa716569eea5464865053603c0e3682584ce 100644 (file)
@@ -208,6 +208,41 @@ int mnist_eval(
     return prediction;
 }
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int wasm_eval(uint8_t *digitPtr)
+{
+    mnist_model model;
+    if (!mnist_model_load("models/mnist/ggml-model-f32.bin", model)) {
+        fprintf(stderr, "error loading model\n");
+        return -1;
+    }
+    std::vector<float> digit(digitPtr, digitPtr + 784);
+    int result = mnist_eval(model, 1, digit);
+    ggml_free(model.ctx);
+    return result;
+}
+
+int wasm_random_digit(char *digitPtr)
+{
+    auto fin = std::ifstream("models/mnist/t10k-images.idx3-ubyte", std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "failed to open digits file\n");
+        return 0;
+    }
+    srand(time(NULL));
+    // Seek to a random digit: 16-byte header + 28*28 * (random 0 - 10000)
+    fin.seekg(16 + 784 * (rand() % 10000));
+    fin.read(digitPtr, 784);
+    return 1;
+}
+
+#ifdef __cplusplus
+}
+#endif
+
 int main(int argc, char ** argv) {
     srand(time(NULL));
     ggml_time_init();
diff --git a/examples/mnist/web/index.html b/examples/mnist/web/index.html
new file mode 100644 (file)
index 0000000..d62bead
--- /dev/null
@@ -0,0 +1,126 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="utf-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
+    <title>MNIST with GGML</title>
+    <script src="mnist.js"></script>
+</head>
+<body>
+    <h2>MNIST digit recognizer with GGML</h2>
+    <p>Draw a single digit on the canvas below:</p>
+    <canvas id="ggCanvas" width="364" height="364" style="border:2px solid #d3d3d3;">
+        Your browser does not support the HTML canvas tag.
+    </canvas>
+    <div>
+        <button id="clear" onclick="onClear()">Clear</button>
+        <button id="random" onclick="onRandom()">Random</button>
+    </div>
+    <div>
+        <p id="prediction"></p>
+    </div>
+    <script>
+"use strict";
+const DIGIT_SIZE = 28; // digits are 28x28 pixels
+var canvas = document.getElementById("ggCanvas");
+var ctx = canvas.getContext("2d");
+var digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
+var dragging = false;
+
+function onClear(event) {
+    ctx.clearRect(0, 0, canvas.width, canvas.height);
+    digit.fill(0);
+    document.getElementById("prediction").innerHTML = "";
+}
+
+function onRandom(event) {
+    onClear();
+    var buf = Module._malloc(digit.length);
+    if (buf == 0) {
+        console.log("failed to allocate memory");
+        return;
+    }
+    let ret = Module.ccall('wasm_random_digit', null, ['number'], [buf]);
+    let digitBytes = new Uint8Array(Module.HEAPU8.buffer, buf, digit.length);
+    for (let i = 0; i < digit.length; i++) {
+        digit[i] = digitBytes[i];
+        let x = i % DIGIT_SIZE;
+        let y = Math.floor(i / DIGIT_SIZE);
+        setPixel(x, y, digit[i]);
+    }
+    Module._free(buf);
+    onMouseUp();
+}
+
+// Get the position of the mouse relative to the canvas
+function getMousePos(event) {
+    if (event.touches !== undefined && event.touches.length > 0) {
+        event = event.touches[0];
+    }
+    var rect = canvas.getBoundingClientRect();
+    return [Math.floor(event.clientX) - rect.left, Math.floor(event.clientY) - rect.top];
+}
+
+function setPixel(x, y, val) {
+    digit[y * DIGIT_SIZE + x] = val;
+    let canvasX = x * 13;
+    let canvasY = y * 13;
+    let color = 255 - val;
+    ctx.fillStyle = "#" + color.toString(16) + color.toString(16) + color.toString(16);
+    ctx.fillRect(canvasX, canvasY, 13, 13);
+}
+
+function onMouseDown(e) {
+    dragging = true;
+    let [mouseX, mouseY] = getMousePos(e);
+    setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
+}
+
+function onMouseUp(e) {
+    dragging = false;
+    var buf = Module._malloc(digit.length);
+    if (buf == 0) {
+        console.log("failed to allocate memory");
+        return;
+    }
+    Module.HEAPU8.set(digit, buf);
+    let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
+    Module._free(buf);
+    if (prediction >= 0) {
+        document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
+    }
+}
+function onMouseMove(e) {
+    if (dragging) {
+        let [mouseX, mouseY] = getMousePos(e);
+        setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
+    }
+}
+
+// Prevent scrolling when touching the canvas
+document.body.addEventListener("touchstart", function (e) {
+if (e.target == canvas) {
+    e.preventDefault();
+}
+}, {passive: false});
+document.body.addEventListener("touchend", function (e) {
+if (e.target == canvas) {
+    e.preventDefault();
+}
+}, {passive: false});
+document.body.addEventListener("touchmove", function (e) {
+if (e.target == canvas) {
+    e.preventDefault();
+}
+}, {passive: false});
+
+// Use the same handlers for mouse and touch events
+canvas.onmousedown = onMouseDown;
+canvas.onmouseup = onMouseUp;
+canvas.onmousemove = onMouseMove;
+canvas.ontouchstart = onMouseDown;
+canvas.ontouchend = onMouseUp;
+canvas.ontouchmove = onMouseMove;
+    </script>
+</body>
+</html>