<div>
<button id="clear" onclick="onClear()">Clear</button>
<button id="random" onclick="onRandom()" disabled>Random</button>
+ <button id="download" onclick="onDownload()">Download</button>
</div>
<div>
<p id="prediction"></p>
"use strict";
const DIGIT_SIZE = 28; // digits are 28x28 pixels
var canvas = document.getElementById("ggCanvas");
-var ctx = canvas.getContext("2d");
-var digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
+var ctx = canvas.getContext("2d", { alpha: false, willReadFrequently: true });
+ctx.fillStyle = "white";
+ctx.fillRect(0, 0, canvas.width, canvas.height);
var dragging = false;
+var lastX, lastY;
function onClear(event) {
- ctx.clearRect(0, 0, canvas.width, canvas.height);
- digit.fill(0);
+ ctx.fillStyle = "white";
+ ctx.fillRect(0, 0, canvas.width, canvas.height);
document.getElementById("prediction").innerHTML = "";
}
+function predict(digit) {
+ let buf = Module._malloc(digit.length);
+ if (buf == 0) {
+ console.log("failed to allocate memory");
+ return;
+ }
+ Module.HEAPU8.set(digit, buf);
+ let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
+ Module._free(buf);
+ if (prediction >= 0) {
+ document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
+ }
+}
+
function onRandom(event) {
onClear();
- var buf = Module._malloc(digit.length);
+ const bufLength = DIGIT_SIZE*DIGIT_SIZE;
+ var buf = Module._malloc(bufLength);
if (buf == 0) {
console.log("failed to allocate memory");
return;
}
let ret = Module.ccall('wasm_random_digit', null, ['number'], [buf]);
- let digitBytes = new Uint8Array(Module.HEAPU8.buffer, buf, digit.length);
+ let digit = new Uint8Array(Module.HEAPU8.buffer, buf, bufLength);
for (let i = 0; i < digit.length; i++) {
- digit[i] = digitBytes[i];
let x = i % DIGIT_SIZE;
let y = Math.floor(i / DIGIT_SIZE);
setPixel(x, y, digit[i]);
}
Module._free(buf);
- onMouseUp();
+ predict(digit);
+}
+
+function onDownload(event) {
+ let digit = scaleCanvas();
+ let digitBlob = new Blob([new Uint8Array(digit)], {type: "application/octet-stream"});
+ let url = URL.createObjectURL(digitBlob);
+ let link = document.createElement('a');
+ link.href = url;
+ link.download = "image.raw";
+ document.body.appendChild(link);
+ link.click();
+ document.body.removeChild(link);
}
// Get the position of the mouse relative to the canvas
}
function setPixel(x, y, val) {
- digit[y * DIGIT_SIZE + x] = val;
let canvasX = x * 13;
let canvasY = y * 13;
let color = 255 - val;
function onMouseDown(e) {
dragging = true;
- let [mouseX, mouseY] = getMousePos(e);
- setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
+ [lastX, lastY] = getMousePos(e);
+}
+
+// scale the canvas to 28x28 pixels and return the pixel values as an array
+function scaleCanvas() {
+ let imgData = ctx.getImageData(0, 0, canvas.width, canvas.height);
+ let tempCanvas = document.createElement('canvas');
+ tempCanvas.width = DIGIT_SIZE;
+ tempCanvas.height = DIGIT_SIZE;
+ let tempCtx = tempCanvas.getContext("2d");
+ tempCtx.drawImage(canvas, 0, 0, DIGIT_SIZE, DIGIT_SIZE);
+ let tempImgData = tempCtx.getImageData(0, 0, DIGIT_SIZE, DIGIT_SIZE);
+ let tempData = tempImgData.data;
+ let digit = new Array(DIGIT_SIZE*DIGIT_SIZE).fill(0);
+ for (let i = 0; i < tempData.length; i += 4) {
+ let val = 255 - tempData[i];
+ digit[i / 4] = val;
+ }
+ return digit;
}
function onMouseUp(e) {
dragging = false;
- var buf = Module._malloc(digit.length);
- if (buf == 0) {
- console.log("failed to allocate memory");
- return;
- }
- Module.HEAPU8.set(digit, buf);
- let prediction = Module.ccall('wasm_eval', null, ['number'], [buf]);
- Module._free(buf);
- if (prediction >= 0) {
- document.getElementById("prediction").innerHTML = "Predicted digit is <b>" + prediction + "</b>";
- }
+ let digit = scaleCanvas();
+ predict(digit);
}
+
function onMouseMove(e) {
if (dragging) {
let [mouseX, mouseY] = getMousePos(e);
- setPixel(Math.floor(mouseX / 13), Math.floor(mouseY / 13), 255);
+ ctx.beginPath();
+ ctx.moveTo(lastX, lastY);
+ ctx.lineTo(mouseX, mouseY);
+ ctx.lineWidth = 20;
+ ctx.lineJoin = ctx.lineCap = 'round';
+ ctx.strokeStyle = "#000000";
+ ctx.stroke();
+ ctx.closePath();
+ lastX = mouseX;
+ lastY = mouseY;
}
}