]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
mnist : use CMake to build mnist wasm example (#1269)
authorDaniel Bevenius <redacted>
Wed, 11 Jun 2025 09:15:14 +0000 (11:15 +0200)
committerGitHub <redacted>
Wed, 11 Jun 2025 09:15:14 +0000 (11:15 +0200)
This commit updates the mnist examples to use CMake for building the
WebAssembly (WASM) version of the MNIST example instead of the current
emcc command.

The motivation for this change is that using CMake should make it easier
to maintin with regards to when changes in ggml occur they should not
cause this example to break. Currently the emcc command is outdated and
it was not clear how to updated it which is why this change was made.

Resolves: https://github.com/ggml-org/ggml/issues/1264

examples/mnist/CMakeLists.txt
examples/mnist/README.md
examples/mnist/mnist-common.cpp
examples/mnist/server.py [new file with mode: 0644]

index ef17a7273f1459cfff498c00a546b284567fba4d..a4b5133aa1b6c970de3b6141660248c36cce24a8 100644 (file)
@@ -18,3 +18,41 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)
 set(TEST_TARGET mnist-train)
 add_executable(${TEST_TARGET} mnist-train.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common)
+
+
+#
+# mnist-wasm
+if (EMSCRIPTEN)
+    set(TARGET mnist)
+
+    add_executable(${TARGET} mnist-common.cpp)
+    target_link_libraries(${TARGET} PRIVATE ggml ggml-cpu)
+
+    set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
+        --bind \
+        -s FORCE_FILESYSTEM=1 \
+        -s USE_PTHREADS=1 \
+        -s PTHREAD_POOL_SIZE=10 \
+        -s ASSERTIONS=1 \
+        -s WASM=1 \
+        -s EXPORTED_RUNTIME_METHODS=\"['ccall', 'cwrap', 'setValue', 'getValue']\" \
+        -s EXPORTED_FUNCTIONS=\"['_wasm_eval','_wasm_random_digit','_malloc','_free']\" \
+        -s ALLOW_MEMORY_GROWTH=1 \
+        --preload-file ${CMAKE_CURRENT_SOURCE_DIR}/mnist-f32.gguf@/ \
+        --preload-file ${CMAKE_CURRENT_SOURCE_DIR}/t10k-images-idx3-ubyte@/ \
+        ")
+
+    # Copy output to web directory
+    add_custom_command(
+        TARGET ${TARGET} POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E copy
+            ${CMAKE_BINARY_DIR}/bin/mnist.js
+            ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.js
+        COMMAND ${CMAKE_COMMAND} -E copy
+            ${CMAKE_BINARY_DIR}/bin/mnist.wasm
+            ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.wasm
+        COMMAND ${CMAKE_COMMAND} -E copy
+            ${CMAKE_BINARY_DIR}/bin/mnist.worker.js
+            ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.worker.js
+        )
+endif()
index 51e394efedf2e996f8a9b88ef29b7c03480338e1..af80261e8ec1e6e9672a4c8ae5529a300bf29149 100644 (file)
@@ -178,18 +178,23 @@ Symlinking these files will *not* work!
 Compile the code like so:
 
 ```bash
-$ emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c ../../src/ggml-aarch64.c mnist-common.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file mnist-f32.gguf --preload-file t10k-images-idx3-ubyte
+$ cd ../../
+$ mkdir -p build-em
+$ emcmake cmake .. -DGGML_BUILD_EXAMPLES=ON \
+    -DCMAKE_C_FLAGS="-pthread -matomics -mbulk-memory" \
+    -DCMAKE_CXX_FLAGS="-pthread -matomics -mbulk-memory"
+$ make mnist
 ```
 
-The compilation output is in `examples/mnist/web`.
+The compilation output is copied into `examples/mnist/web`.
 To run it, you need an HTTP server.
 For example:
 
 ``` bash
-$ cd web
-$ python3 -m http.server
+$ python3 examples/mnist/server.py
 
-Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
+Serving directory '/home/danbev/work/ai/ggml/examples/mnist/web' at http://localhost:8000
+Application context root: http://localhost:8000/
 ```
 
 The web demo can then be accessed via the link printed on the console.
index a303bcec085b3c8c3db53ba698318d492c32f151..301151630c09558975614c58bcbee0a0b4f241d2 100644 (file)
@@ -227,7 +227,8 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str
     // The space in ctx_gguf exactly fits the model weights,
     // the images (which also need to be statically allocated) need to be put in a different context.
 
-    model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL);
+    model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, nbatch_physical);
+
     ggml_set_name(model.images, "images");
     ggml_set_input(model.images);
 
@@ -458,7 +459,11 @@ int wasm_eval(uint8_t * digitPtr) {
 
     ggml_opt_dataset_t dataset = ggml_opt_dataset_init(GGML_TYPE_F32, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NCLASSES, 1, 1);
     struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
-    memcpy(data->data, digitPtr, ggml_nbytes(data));
+
+    float * buf = ggml_get_data_f32(data);
+    for (int i = 0; i < MNIST_NINPUT; ++i) {
+        buf[i] = digitPtr[i] / 255.0f;
+    }
     ggml_set_zero(ggml_opt_dataset_labels(dataset)); // The labels are not needed.
 
     mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU", /*nbatch_logical =*/ 1, /*nbatch_physical =*/ 1);
diff --git a/examples/mnist/server.py b/examples/mnist/server.py
new file mode 100644 (file)
index 0000000..588b396
--- /dev/null
@@ -0,0 +1,36 @@
+import http.server
+import socketserver
+import os
+import sys
+
+DIRECTORY = os.path.abspath(os.path.join(os.path.dirname(__file__), 'web'))
+PORT = 8000
+
+class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, directory=DIRECTORY, **kwargs)
+
+    def end_headers(self):
+        # Add required headers for SharedArrayBuffer
+        self.send_header("Cross-Origin-Opener-Policy", "same-origin")
+        self.send_header("Cross-Origin-Embedder-Policy", "require-corp")
+        self.send_header("Access-Control-Allow-Origin", "*")
+        super().end_headers()
+
+# Enable address reuse
+class CustomServer(socketserver.TCPServer):
+    allow_reuse_address = True
+
+try:
+    with CustomServer(("", PORT), CustomHTTPRequestHandler) as httpd:
+        print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}")
+        print(f"Application context root: http://localhost:{PORT}/")
+        try:
+            httpd.serve_forever()
+        except KeyboardInterrupt:
+            print("\nServer stopped.")
+            # Force complete exit
+            sys.exit(0)
+except OSError as e:
+    print(f"Error: {e}")
+    sys.exit(1)