]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
mnist : smooth user input (#199)
authorRadoslav Gerganov <redacted>
Fri, 26 May 2023 08:53:18 +0000 (11:53 +0300)
committerGitHub <redacted>
Fri, 26 May 2023 08:53:18 +0000 (11:53 +0300)
Drawing on the canvas is now smooth. The final image which is used for
prediction is obtained by down-scaling the canvas to 28x28 pixels.
Download button is aslo added for downloading raw image values.

examples/mnist/web/index.html

index 1bd01ae500bcc76c46476da8499bdaddbd41f3a5..ab1ef1778becd0eedd2ae03ec7a5e10f1999c63a 100644 (file)
@@ -15,6 +15,7 @@
     <div>
         <button id="clear" onclick="onClear()">Clear</button>
         <button id="random" onclick="onRandom()" disabled>Random</button>
+        <button id="download" onclick="onDownload()">Download</button>
     </div>
     <div>
         <p id="prediction"></p>
 "use strict";
 const DIGIT_SIZE = 28; // digits are 28x28 pixels
 var canvas = document.getElementById("ggCanvas");
-var ctx = canvas.getContext("2d");
-var digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
+var ctx = canvas.getContext("2d", { alpha: false, willReadFrequently: true });
+ctx.fillStyle = "white";
+ctx.fillRect(0, 0, canvas.width, canvas.height);
 var dragging = false;
+var lastX, lastY;
 
 function onClear(event) {
-    ctx.clearRect(0, 0, canvas.width, canvas.height);
-    digit.fill(0);
+    ctx.fillStyle = "white";
+    ctx.fillRect(0, 0, canvas.width, canvas.height);
     document.getElementById("prediction").innerHTML = "";
 }
 
+function predict(digit) {
+    let buf = Module._malloc(digit.length);
+    if (buf == 0) {
+        console.log("failed to allocate memory");
+        return;
+    }
+    Module.HEAPU8.set(digit, buf);
+    let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
+    Module._free(buf);
+    if (prediction >= 0) {
+        document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
+    }
+}
+
 function onRandom(event) {
     onClear();
-    var buf = Module._malloc(digit.length);
+    const bufLength = DIGIT_SIZE*DIGIT_SIZE;
+    var buf = Module._malloc(bufLength);
     if (buf == 0) {
         console.log("failed to allocate memory");
         return;
     }
     let ret = Module.ccall('wasm_random_digit', null, ['number'], [buf]);
-    let digitBytes = new Uint8Array(Module.HEAPU8.buffer, buf, digit.length);
+    let digit = new Uint8Array(Module.HEAPU8.buffer, buf, bufLength);
     for (let i = 0; i < digit.length; i++) {
-        digit[i] = digitBytes[i];
         let x = i % DIGIT_SIZE;
         let y = Math.floor(i / DIGIT_SIZE);
         setPixel(x, y, digit[i]);
     }
     Module._free(buf);
-    onMouseUp();
+    predict(digit);
+}
+
+function onDownload(event) {
+    let digit = scaleCanvas();
+    let digitBlob = new Blob([new Uint8Array(digit)], {type: "application/octet-stream"});
+    let url = URL.createObjectURL(digitBlob);
+    let link = document.createElement('a');
+    link.href = url;
+    link.download = "image.raw";
+    document.body.appendChild(link);
+    link.click();
+    document.body.removeChild(link);
 }
 
 // Get the position of the mouse relative to the canvas
@@ -62,7 +91,6 @@ function getMousePos(event) {
 }
 
 function setPixel(x, y, val) {
-    digit[y * DIGIT_SIZE + x] = val;
     let canvasX = x * 13;
     let canvasY = y * 13;
     let color = 255 - val;
@@ -72,28 +100,46 @@ function setPixel(x, y, val) {
 
 function onMouseDown(e) {
     dragging = true;
-    let [mouseX, mouseY] = getMousePos(e);
-    setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
+    [lastX, lastY] = getMousePos(e);
+}
+
+// scale the canvas to 28x28 pixels and return the pixel values as an array
+function scaleCanvas() {
+    let imgData = ctx.getImageData(0, 0, canvas.width, canvas.height);
+    let tempCanvas = document.createElement('canvas');
+    tempCanvas.width = DIGIT_SIZE;
+    tempCanvas.height = DIGIT_SIZE;
+    let tempCtx = tempCanvas.getContext("2d");
+    tempCtx.drawImage(canvas, 0, 0, DIGIT_SIZE, DIGIT_SIZE);
+    let tempImgData = tempCtx.getImageData(0, 0, DIGIT_SIZE, DIGIT_SIZE);
+    let tempData = tempImgData.data;
+    let digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
+    for (let i = 0; i < tempData.length; i += 4) {
+        let val = 255 - tempData[i];
+        digit[i / 4] = val;
+    }
+    return digit;
 }
 
 function onMouseUp(e) {
     dragging = false;
-    var buf = Module._malloc(digit.length);
-    if (buf == 0) {
-        console.log("failed to allocate memory");
-        return;
-    }
-    Module.HEAPU8.set(digit, buf);
-    let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
-    Module._free(buf);
-    if (prediction >= 0) {
-        document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
-    }
+    let digit = scaleCanvas();
+    predict(digit);
 }
+
 function onMouseMove(e) {
     if (dragging) {
         let [mouseX, mouseY] = getMousePos(e);
-        setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
+        ctx.beginPath();
+        ctx.moveTo(lastX, lastY);
+        ctx.lineTo(mouseX, mouseY);
+        ctx.lineWidth = 20;
+        ctx.lineJoin = ctx.lineCap = 'round';
+        ctx.strokeStyle = "#000000";
+        ctx.stroke();
+        ctx.closePath();
+        lastX = mouseX;
+        lastY = mouseY;
     }
 }