From: Daniel Bevenius Date: Wed, 11 Jun 2025 09:15:14 +0000 (+0200) Subject: mnist : use CMake to build mnist wasm example (#1269) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=21303b6c844e9c5d9c5078bcd63e5dc4f6001619;p=pkg%2Fggml%2Fsources%2Fggml mnist : use CMake to build mnist wasm example (#1269) This commit updates the mnist examples to use CMake for building the WebAssembly (WASM) version of the MNIST example instead of the current emcc command. The motivation for this change is that using CMake should make it easier to maintain: when changes in ggml occur, they should not cause this example to break. Currently the emcc command is outdated and it was not clear how to update it, which is why this change was made. Resolves: https://github.com/ggml-org/ggml/issues/1264 --- diff --git a/examples/mnist/CMakeLists.txt b/examples/mnist/CMakeLists.txt index ef17a727..a4b5133a 100644 --- a/examples/mnist/CMakeLists.txt +++ b/examples/mnist/CMakeLists.txt @@ -18,3 +18,41 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common) set(TEST_TARGET mnist-train) add_executable(${TEST_TARGET} mnist-train.cpp) target_link_libraries(${TEST_TARGET} PRIVATE ggml common mnist-common) + + +# +# mnist-wasm +if (EMSCRIPTEN) + set(TARGET mnist) + + add_executable(${TARGET} mnist-common.cpp) + target_link_libraries(${TARGET} PRIVATE ggml ggml-cpu) + + set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \ + --bind \ + -s FORCE_FILESYSTEM=1 \ + -s USE_PTHREADS=1 \ + -s PTHREAD_POOL_SIZE=10 \ + -s ASSERTIONS=1 \ + -s WASM=1 \ + -s EXPORTED_RUNTIME_METHODS=\"['ccall', 'cwrap', 'setValue', 'getValue']\" \ + -s EXPORTED_FUNCTIONS=\"['_wasm_eval','_wasm_random_digit','_malloc','_free']\" \ + -s ALLOW_MEMORY_GROWTH=1 \ + --preload-file ${CMAKE_CURRENT_SOURCE_DIR}/mnist-f32.gguf@/ \ + --preload-file ${CMAKE_CURRENT_SOURCE_DIR}/t10k-images-idx3-ubyte@/ \ + ") + + # Copy output to web directory + add_custom_command( + TARGET ${TARGET} POST_BUILD + COMMAND 
${CMAKE_COMMAND} -E copy + ${CMAKE_BINARY_DIR}/bin/mnist.js + ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.js + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_BINARY_DIR}/bin/mnist.wasm + ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.wasm + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_BINARY_DIR}/bin/mnist.worker.js + ${CMAKE_CURRENT_SOURCE_DIR}/web/mnist.worker.js + ) +endif() diff --git a/examples/mnist/README.md b/examples/mnist/README.md index 51e394ef..af80261e 100644 --- a/examples/mnist/README.md +++ b/examples/mnist/README.md @@ -178,18 +178,23 @@ Symlinking these files will *not* work! Compile the code like so: ```bash -$ emcc -I../../include -I../../include/ggml -I../../examples ../../src/ggml.c ../../src/ggml-quants.c ../../src/ggml-aarch64.c mnist-common.cpp -o web/mnist.js -s EXPORTED_FUNCTIONS='["_wasm_eval","_wasm_random_digit","_malloc","_free"]' -s EXPORTED_RUNTIME_METHODS='["ccall"]' -s ALLOW_MEMORY_GROWTH=1 --preload-file mnist-f32.gguf --preload-file t10k-images-idx3-ubyte +$ cd ../../ +$ mkdir -p build-em +$ emcmake cmake .. -DGGML_BUILD_EXAMPLES=ON \ + -DCMAKE_C_FLAGS="-pthread -matomics -mbulk-memory" \ + -DCMAKE_CXX_FLAGS="-pthread -matomics -mbulk-memory" +$ make mnist ``` -The compilation output is in `examples/mnist/web`. +The compilation output is copied into `examples/mnist/web`. To run it, you need an HTTP server. For example: ``` bash -$ cd web -$ python3 -m http.server +$ python3 examples/mnist/server.py -Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ... +Serving directory '/home/danbev/work/ai/ggml/examples/mnist/web' at http://localhost:8000 +Application context root: http://localhost:8000/ ``` The web demo can then be accessed via the link printed on the console. 
diff --git a/examples/mnist/mnist-common.cpp b/examples/mnist/mnist-common.cpp index a303bcec..30115163 100644 --- a/examples/mnist/mnist-common.cpp +++ b/examples/mnist/mnist-common.cpp @@ -227,7 +227,8 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str // The space in ctx_gguf exactly fits the model weights, // the images (which also need to be statically allocated) need to be put in a different context. - model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL); + model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, nbatch_physical); + ggml_set_name(model.images, "images"); ggml_set_input(model.images); @@ -458,7 +459,11 @@ int wasm_eval(uint8_t * digitPtr) { ggml_opt_dataset_t dataset = ggml_opt_dataset_init(GGML_TYPE_F32, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NCLASSES, 1, 1); struct ggml_tensor * data = ggml_opt_dataset_data(dataset); - memcpy(data->data, digitPtr, ggml_nbytes(data)); + + float * buf = ggml_get_data_f32(data); + for (int i = 0; i < MNIST_NINPUT; ++i) { + buf[i] = digitPtr[i] / 255.0f; + } ggml_set_zero(ggml_opt_dataset_labels(dataset)); // The labels are not needed. 
mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU", /*nbatch_logical =*/ 1, /*nbatch_physical =*/ 1); diff --git a/examples/mnist/server.py b/examples/mnist/server.py new file mode 100644 index 00000000..588b3968 --- /dev/null +++ b/examples/mnist/server.py @@ -0,0 +1,36 @@ +import http.server +import socketserver +import os +import sys + +DIRECTORY = os.path.abspath(os.path.join(os.path.dirname(__file__), 'web')) +PORT = 8000 + +class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, directory=DIRECTORY, **kwargs) + + def end_headers(self): + # Add required headers for SharedArrayBuffer + self.send_header("Cross-Origin-Opener-Policy", "same-origin") + self.send_header("Cross-Origin-Embedder-Policy", "require-corp") + self.send_header("Access-Control-Allow-Origin", "*") + super().end_headers() + +# Enable address reuse +class CustomServer(socketserver.TCPServer): + allow_reuse_address = True + +try: + with CustomServer(("", PORT), CustomHTTPRequestHandler) as httpd: + print(f"Serving directory '{DIRECTORY}' at http://localhost:{PORT}") + print(f"Application context root: http://localhost:{PORT}/") + try: + httpd.serve_forever() + except KeyboardInterrupt: + print("\nServer stopped.") + # Force complete exit + sys.exit(0) +except OSError as e: + print(f"Error: {e}") + sys.exit(1)