]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: WebGPU backend host improvements and style fixing (#14978)
authorReese Levine <redacted>
Mon, 4 Aug 2025 15:52:43 +0000 (08:52 -0700)
committerGitHub <redacted>
Mon, 4 Aug 2025 15:52:43 +0000 (08:52 -0700)
* Add parameter buffer pool, batching of submissions, refactor command building/submission

* Add header for linux builds

* Free staged parameter buffers at once

* Format with clang-format

* Fix thread-safe implementation

* Use device implicit synchronization

* Update workflow to use custom release

* Remove testing branch workflow

.github/workflows/build.yml
ggml/src/ggml-webgpu/ggml-webgpu.cpp

index c6d51fb0c2e7ed326de64d0ba2c921cf9df79a48..3d4f837e2489506449dc584783ed75bf14e703c3 100644 (file)
@@ -159,31 +159,15 @@ jobs:
       - name: Dawn Dependency
         id: dawn-depends
         run: |
-          ARTIFACTS_JSON=$(curl -s -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -H "X-GitHub-Api-Version: 2022-11-28" \
-            "https://api.github.com/repos/google/dawn/actions/artifacts")
-          echo "Finding latest macos-latest-Release artifact..."
-          DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
-            | sort_by(.created_at)
-            | reverse
-            | map(select(.name | test("macos-latest-Release$")))
-            | .[0].archive_download_url')
-          if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
-            echo "No suitable Dawn artifact found!"
-            exit 1
-          fi
-          echo "Downloading from: $DOWNLOAD_URL"
-          curl -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -o artifact.zip "$DOWNLOAD_URL"
-          unzip artifact.zip
+          DAWN_VERSION="v1.0.0"
+          DAWN_OWNER="reeselevine"
+          DAWN_REPO="dawn"
+          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
+          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+          curl -L -o artifact.tar.gz \
+            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar_file=$(find . -name '*.tar.gz' | head -n 1)
-          echo "Extracting: $tar_file"
-          tar -xvf "$tar_file" -C dawn --strip-components=1
+          tar -xvf artifact.tar.gz -C dawn --strip-components=1
 
       - name: Build
         id: cmake_build
@@ -433,31 +417,15 @@ jobs:
         id: dawn-depends
         run: |
           sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
-          ARTIFACTS_JSON=$(curl -s -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -H "X-GitHub-Api-Version: 2022-11-28" \
-            "https://api.github.com/repos/google/dawn/actions/artifacts")
-          echo "Finding latest ubuntu-latest-Release artifact..."
-          DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
-            | sort_by(.created_at)
-            | reverse
-            | map(select(.name | test("ubuntu-latest-Release$")))
-            | .[0].archive_download_url')
-          if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
-            echo "No suitable Dawn artifact found!"
-            exit 1
-          fi
-          echo "Downloading from: $DOWNLOAD_URL"
-          curl -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -o artifact.zip "$DOWNLOAD_URL"
-          unzip artifact.zip
+          DAWN_VERSION="v1.0.0"
+          DAWN_OWNER="reeselevine"
+          DAWN_REPO="dawn"
+          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
+          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+          curl -L -o artifact.tar.gz \
+            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar_file=$(find . -name '*.tar.gz' | head -n 1)
-          echo "Extracting: $tar_file"
-          tar -xvf "$tar_file" -C dawn --strip-components=1
+          tar -xvf artifact.tar.gz -C dawn --strip-components=1
 
       - name: Build
         id: cmake_build
index c5abc69343357fa501b2b27566307f0a96e00dc9..91411d9c0014b351b21ed90d0dcad67f3cdd0bcb 100644 (file)
@@ -1,34 +1,41 @@
-#include "ggml-webgpu.h"
+/*
+    WebGPU backend implementation.
+    Note: Use ClangFormat to format this file.
+*/
 
-#include <webgpu/webgpu_cpp.h>
+#include "ggml-webgpu.h"
 
-#include "ggml-impl.h"
 #include "ggml-backend-impl.h"
-
+#include "ggml-impl.h"
 #include "ggml-wgsl-shaders.hpp"
 
+#include <webgpu/webgpu_cpp.h>
+
+#include <condition_variable>
 #include <cstring>
 #include <iostream>
 #include <mutex>
+#include <string>
 #include <vector>
 
 #ifdef GGML_WEBGPU_DEBUG
-#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
+#    define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
 #else
-#define WEBGPU_LOG_DEBUG(msg) ((void) 0)
-#endif // GGML_WEBGPU_DEBUG
+#    define WEBGPU_LOG_DEBUG(msg) ((void) 0)
+#endif  // GGML_WEBGPU_DEBUG
 
 /* Constants */
 
-#define WEBGPU_MUL_MAT_WG_SIZE 64
-#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts
-#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets
-#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
+#define WEBGPU_MUL_MAT_WG_SIZE           64
+#define WEBGPU_NUM_PARAM_BUFS            100
+#define WEBGPU_PARAMS_BUF_SIZE_BYTES     256
+#define WEBGPU_STORAGE_BUF_BINDING_MULT  4  // a storage buffer binding size must be a multiple of 4
 
 /* End Constants */
 
 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
-static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
+static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT
 
 // Always returns the base offset of a tensor, regardless of views.
 static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
@@ -40,100 +47,172 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
 
 /* Struct definitions */
 
+// Forward reference
+static void ggml_webgpu_create_buffer(wgpu::Device &    device,
+                                      wgpu::Buffer &    buffer,
+                                      size_t            size,
+                                      wgpu::BufferUsage usage,
+                                      const char *      label);
+
+struct webgpu_param_bufs {
+    wgpu::Buffer host_buf;
+    wgpu::Buffer dev_buf;
+};
+
+// Holds a pool of parameter buffers for WebGPU operations
+struct webgpu_param_buf_pool {
+    std::vector<webgpu_param_bufs> free;
+
+    std::mutex mutex;
+
+    std::condition_variable cv;
+
+    void init(wgpu::Device device) {
+        for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
+            wgpu::Buffer host_buf;
+            wgpu::Buffer dev_buf;
+            ggml_webgpu_create_buffer(device,
+                                      host_buf,
+                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
+                                      "ggml_webgpu_host_params_buf");
+            ggml_webgpu_create_buffer(device,
+                                      dev_buf,
+                                      WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                      wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                                      "ggml_webgpu_dev_params_buf");
+            free.push_back({ host_buf, dev_buf });
+        }
+    }
+
+    webgpu_param_bufs alloc_bufs() {
+        std::unique_lock<std::mutex> lock(mutex);
+        cv.wait(lock, [this] { return !free.empty(); });
+        webgpu_param_bufs bufs = free.back();
+        free.pop_back();
+        return bufs;
+    }
+
+    void free_bufs(std::vector<webgpu_param_bufs> bufs) {
+        std::lock_guard<std::mutex> lock(mutex);
+        free.insert(free.end(), bufs.begin(), bufs.end());
+        cv.notify_all();
+    }
+
+    void cleanup() {
+        std::lock_guard<std::mutex> lock(mutex);
+        for (auto & bufs : free) {
+            bufs.host_buf.Destroy();
+            bufs.dev_buf.Destroy();
+        }
+        free.clear();
+    }
+};
+
 // All the base objects needed to run operations on a WebGPU device
 struct webgpu_context_struct {
     wgpu::Instance instance;
-    wgpu::Adapter adapter;
-    wgpu::Device device;
-    wgpu::Queue queue;
-    wgpu::Limits limits;
-    wgpu::SupportedFeatures features;
+    wgpu::Adapter  adapter;
+    wgpu::Device   device;
+    wgpu::Queue    queue;
+    wgpu::Limits   limits;
 
-    std::mutex mutex;
-    bool device_initialized = false;
+    std::recursive_mutex mutex;
+    std::mutex           get_tensor_mutex;
+    std::mutex           init_mutex;
+
+    bool device_init = false;
+
+    webgpu_param_buf_pool param_buf_pool;
 
-    // pipelines and parameter buffers
-    // TODO: reuse params buffers for different pipelines when possible
     wgpu::ComputePipeline memset_pipeline;
-    wgpu::Buffer memset_params_dev_buf;
-    wgpu::Buffer memset_params_host_buf;
     wgpu::ComputePipeline mul_mat_pipeline;
-    wgpu::Buffer mul_mat_params_dev_buf;
-    wgpu::Buffer mul_mat_params_host_buf;
     wgpu::ComputePipeline cpy_pipeline;
-    wgpu::Buffer cpy_params_dev_buf;
-    wgpu::Buffer cpy_params_host_buf;
 
     size_t memset_bytes_per_thread;
 
     // Staging buffer for reading data from the GPU
     wgpu::Buffer get_tensor_staging_buf;
+
+    // Command buffers which need to be submitted
+    std::vector<wgpu::CommandBuffer> staged_command_bufs;
+
+    // Parameter buffers associated with the staged command buffers
+    std::vector<webgpu_param_bufs> staged_param_bufs;
 };
 
 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
 
 struct ggml_backend_webgpu_reg_context {
     webgpu_context webgpu_ctx;
-
-    size_t device_count;
-    const char * name;
+    size_t         device_count;
+    const char *   name;
 };
 
 struct ggml_backend_webgpu_device_context {
     webgpu_context webgpu_ctx;
-
-    std::string device_name;
-    std::string device_desc;
+    std::string    device_name;
+    std::string    device_desc;
 };
 
 struct ggml_backend_webgpu_context {
     webgpu_context webgpu_ctx;
-
-    std::string name;
+    std::string    name;
 };
 
 struct ggml_backend_webgpu_buffer_context {
     webgpu_context webgpu_ctx;
-
-    wgpu::Buffer buffer;
+    wgpu::Buffer   buffer;
 
     ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
-        webgpu_ctx(ctx), buffer(buf) {
-    }
+        webgpu_ctx(std::move(ctx)),
+        buffer(std::move(buf)) {}
 };
 
 /* End struct definitions */
 
 /* WebGPU object initializations */
 
-static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector<wgpu::ConstantEntry> &constants = {}) {
+static void ggml_webgpu_create_pipeline(wgpu::Device &                           device,
+                                        wgpu::ComputePipeline &                  pipeline,
+                                        const char *                             shader_code,
+                                        const char *                             label,
+                                        const std::vector<wgpu::ConstantEntry> & constants = {}) {
     WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
+
     wgpu::ShaderSourceWGSL shader_source;
     shader_source.code = shader_code;
+
     wgpu::ShaderModuleDescriptor shader_desc;
     shader_desc.nextInChain = &shader_source;
+
     wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
 
     wgpu::ComputePipelineDescriptor pipeline_desc;
-    pipeline_desc.label = label;
-    pipeline_desc.compute.module = shader_module;
-    pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
-    pipeline_desc.layout = nullptr; // nullptr means auto layout
+    pipeline_desc.label              = label;
+    pipeline_desc.compute.module     = shader_module;
+    pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
+    pipeline_desc.layout             = nullptr;  // nullptr means auto layout
     if (constants.size() > 0) {
-        pipeline_desc.compute.constants = constants.data();
+        pipeline_desc.compute.constants     = constants.data();
         pipeline_desc.compute.constantCount = constants.size();
     }
     pipeline = device.CreateComputePipeline(&pipeline_desc);
 }
 
-static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) {
+static void ggml_webgpu_create_buffer(wgpu::Device &    device,
+                                      wgpu::Buffer &    buffer,
+                                      size_t            size,
+                                      wgpu::BufferUsage usage,
+                                      const char *      label) {
     WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
 
     wgpu::BufferDescriptor buffer_desc;
-    buffer_desc.size = size;
-    buffer_desc.usage = usage;
-    buffer_desc.label = label;
+    buffer_desc.size             = size;
+    buffer_desc.usage            = usage;
+    buffer_desc.label            = label;
     buffer_desc.mappedAtCreation = false;
+
     // TODO: error handling
     buffer = device.CreateBuffer(&buffer_desc);
 }
@@ -142,75 +221,133 @@ static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer
 
 /** WebGPU Actions */
 
-static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) {
-    ctx->instance.WaitAny(buffer.MapAsync(
-        mode, offset, size, wgpu::CallbackMode::WaitAnyOnly,
-        [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
-            if (status != wgpu::MapAsyncStatus::Success) {
-                GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data);
+static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
+    // Wait for the queue to finish processing all commands
+    ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
+                              wgpu::CallbackMode::AllowSpontaneous,
+                              [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+                                  if (status != wgpu::QueueWorkDoneStatus::Success) {
+                                      GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
+                                  }
+                              }),
+                          UINT64_MAX);
+}
+
+static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
+    ctx->staged_command_bufs.clear();
+    std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
+    // Free the staged parameter buffers once the submission completes
+    ctx->queue.OnSubmittedWorkDone(
+        wgpu::CallbackMode::AllowSpontaneous,
+        [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+            if (status != wgpu::QueueWorkDoneStatus::Success) {
+                GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data);
             }
-        }),
-        UINT64_MAX
-    );
-}
-
-static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) {
-    std::lock_guard<std::mutex> lock(ctx->mutex);
-    wgpu::Device device = ctx->device;
-
-    // map the host parameters buffer
-    ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize());
-    uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange();
-
-    params[0] = (uint32_t)offset;
-    params[1] = (uint32_t)size;
-    params[2] = value;
-    ctx->memset_params_host_buf.Unmap();
-
-    wgpu::BindGroupEntry entries[2];
-    entries[0].binding = 0; // binding for the buffer to memset
-    entries[0].buffer = buf;
-    entries[0].offset = 0;
-    entries[0].size = buf.GetSize();
-    entries[1].binding = 1; // binding for the parameters
-    entries[1].buffer = ctx->memset_params_dev_buf;
-    entries[1].offset = 0;
-    entries[1].size = ctx->memset_params_dev_buf.GetSize();
+            // Free the staged parameter buffers
+            ctx->param_buf_pool.free_bufs(staged_param_bufs);
+        });
+}
+
+static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
+                                           wgpu::Buffer &   buffer,
+                                           wgpu::MapMode    mode,
+                                           size_t           offset,
+                                           size_t           size) {
+    ctx->instance.WaitAny(buffer.MapAsync(mode,
+                                          offset,
+                                          size,
+                                          wgpu::CallbackMode::AllowSpontaneous,
+                                          [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+                                              if (status != wgpu::MapAsyncStatus::Success) {
+                                                  GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
+                                                                 message.data);
+                                              }
+                                          }),
+                          UINT64_MAX);
+}
+
+static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &                  ctx,
+                                                  wgpu::ComputePipeline &           pipeline,
+                                                  std::vector<uint32_t>             params,
+                                                  std::vector<wgpu::BindGroupEntry> bind_group_entries,
+                                                  uint32_t                          wg_x,
+                                                  bool                              submit_imm = false) {
+    webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+
+    ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
+    uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
+    for (size_t i = 0; i < params.size(); i++) {
+        _params[i] = params[i];
+    };
+
+    params_bufs.host_buf.Unmap();
+
+    uint32_t params_bufs_binding_num = bind_group_entries.size();
+    bind_group_entries.push_back({ .binding = params_bufs_binding_num,
+                                   .buffer  = params_bufs.dev_buf,
+                                   .offset  = 0,
+                                   .size    = params_bufs.dev_buf.GetSize() });
 
     wgpu::BindGroupDescriptor bind_group_desc;
-    bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0);
-    bind_group_desc.entryCount = 2;
-    bind_group_desc.label = "ggml_memset";
-    bind_group_desc.entries = entries;
-    wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
+    bind_group_desc.layout     = pipeline.GetBindGroupLayout(0);
+    bind_group_desc.entryCount = bind_group_entries.size();
+    bind_group_desc.entries    = bind_group_entries.data();
+    wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
 
-    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-    encoder.CopyBufferToBuffer(
-        ctx->memset_params_host_buf, 0,
-        ctx->memset_params_dev_buf, 0,
-        ctx->memset_params_dev_buf.GetSize()
-    );
+    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
+    encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
     wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
-    pass.SetPipeline(ctx->memset_pipeline);
+    pass.SetPipeline(pipeline);
     pass.SetBindGroup(0, bind_group);
-    size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
-    pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1);
+    pass.DispatchWorkgroups(wg_x, 1, 1);
     pass.End();
     wgpu::CommandBuffer commands = encoder.Finish();
+    if (submit_imm) {
+        // Submit immediately
+        ctx->queue.Submit(1, &commands);
+        ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
+                                       [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+                                           if (status != wgpu::QueueWorkDoneStatus::Success) {
+                                               GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
+                                                              message.data);
+                                           }
+                                           ctx->param_buf_pool.free_bufs({ params_bufs });
+                                       });
+    } else {
+        // Lock the context mutex when pushing to the staging vectors.
+        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        // Enqueue commands and only submit if we have enough staged commands
+        ctx->staged_command_bufs.push_back(commands);
+        ctx->staged_param_bufs.push_back(params_bufs);
+        if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
+            ggml_backend_webgpu_submit_queue(ctx);
+        }
+    }
+}
 
-    ctx->queue.Submit(1, &commands);
+static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
+                                              wgpu::Buffer &   buf,
+                                              uint32_t         value,
+                                              size_t           offset,
+                                              size_t           size) {
+    std::vector<uint32_t>             params  = { (uint32_t) offset, (uint32_t) size, value };
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
+    };
+    size_t   bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
+    uint32_t wg_x         = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
 }
 
-static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
-    // Wait for the queue to finish processing all commands
-    ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly,
-        [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
-            if (status != wgpu::QueueWorkDoneStatus::Success) {
-                GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
-            }
-        }),
-        UINT64_MAX
-    );
+static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) {
+    return webgpu_tensor_offset(tensor) + tensor->view_offs;
+}
+
+static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) {
+    ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
+    return ctx->buffer;
 }
 
 /** End WebGPU Actions */
@@ -218,218 +355,146 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) {
 /** GGML Backend Interface */
 
 static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
-    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
+    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
     return ctx->name.c_str();
 }
 
 static void ggml_backend_webgpu_free(ggml_backend_t backend) {
-    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context;
+    ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
 
     // TODO: cleanup
     GGML_UNUSED(ctx);
 }
 
+static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    size_t src_offset       = ggml_backend_webgpu_tensor_offset(src);
+    // assumes power of 2 offset alignment
+    size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    // align to minimum offset alignment
+    src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t dst_offset       = ggml_backend_webgpu_tensor_offset(dst);
+    size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    uint32_t              ne     = (uint32_t) ggml_nelements(dst);
+    std::vector<uint32_t> params = { ne,
+                                     (uint32_t) (src_misalignment / ggml_type_size(src->type)),
+                                     (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
+                                     // Convert byte-strides to element-strides
+                                     (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+                                     // Logical shape â€” same for both tensors even if permuted
+                                     (uint32_t) src->ne[0],
+                                     (uint32_t) src->ne[1],
+                                     (uint32_t) src->ne[2],
+                                     (uint32_t) src->ne[3] };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_backend_webgpu_tensor_buf(src),
+         .offset  = src_offset,
+         .size    = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
+                  ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
+        { .binding = 1,
+         .buffer  = ggml_backend_webgpu_tensor_buf(dst),
+         .offset  = dst_offset,
+         .size    = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
+                  ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
+    };
+
+    size_t   max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
+    uint32_t wg_x        = (ne + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
+}
+
+static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+    std::vector<uint32_t> params = {
+        (uint32_t) dst->ne[1],                                  // number of rows in result (M)
+        (uint32_t) dst->ne[0],                                  // number of columns in result (N)
+        (uint32_t) src0->ne[0],                                 // number of columns in src0/src1 (K)
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements) of src0 in dimension 1
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements) of src1 in dimension 1
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),  // stride (elements) of src0 in dimension 2
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),  // stride (elements) of src1 in dimension 2
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),  // stride (elements) of src0 in dimension 3
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),  // stride (elements) of src1 in dimension 3
+        (uint32_t) src0->ne[2],                                 // batch size in dimension 2
+        (uint32_t) src0->ne[3],                                 // batch size in dimension 3
+        (uint32_t) (src1->ne[2] / src0->ne[2]),                 // broadcast in dimension 2
+        (uint32_t) (src1->ne[3] / src0->ne[3])                  // broadcast in dimension 3
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_backend_webgpu_tensor_buf(src0),
+         .offset  = ggml_backend_webgpu_tensor_offset(src0),
+         .size    = ggml_nbytes(src0) },
+        { .binding = 1,
+         .buffer  = ggml_backend_webgpu_tensor_buf(src1),
+         .offset  = ggml_backend_webgpu_tensor_offset(src1),
+         .size    = ggml_nbytes(src1) },
+        { .binding = 2,
+         .buffer  = ggml_backend_webgpu_tensor_buf(dst),
+         .offset  = ggml_backend_webgpu_tensor_offset(dst),
+         .size    = ggml_nbytes(dst)  }
+    };
+
+    uint32_t wg_x =
+        (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
+}
+
 // Returns true if node has enqueued work into the queue, false otherwise
-static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){
+static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
         return false;
     }
-
     WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
 
+    ggml_tensor * src0 = node->src[0];
+    ggml_tensor * src1 = node->src[1];
 
     switch (node->op) {
-        // no-ops
+            // no-ops
         case GGML_OP_NONE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
             return false;
-
-        case GGML_OP_CPY: {
-            std::lock_guard<std::mutex> lock(ctx->mutex);
-            const ggml_tensor * src = node->src[0];
-            ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context;
-            size_t src_offset = webgpu_tensor_offset(src) + src->view_offs;
-            // assumes power of 2 offset alignment
-            size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
-            // align to minimum offset alignment
-            src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
-            ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
-            size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
-            size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
-            dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
-
-            wgpu::Device device = ctx->device;
-            ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf,
-                wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize());
-            uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange();
-            uint32_t ne = (uint32_t)ggml_nelements(node);
-            params[0] = ne;
-            params[1] = src_misalignment/ggml_type_size(src->type);
-            params[2] = dst_misalignment/ggml_type_size(node->type);
-
-            // Convert byte-strides to element-strides
-            params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type);
-            params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type);
-            params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type);
-            params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type);
-            params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type);
-            params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type);
-            params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type);
-            params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type);
-            // Logical shape â€” same for both tensors even if permuted
-            params[11] = (uint32_t)(src->ne[0]);
-            params[12] = (uint32_t)(src->ne[1]);
-            params[13] = (uint32_t)(src->ne[2]);
-            params[14] = (uint32_t)(src->ne[3]);
-
-            ctx->cpy_params_host_buf.Unmap();
-
-            wgpu::BindGroupEntry entries[3];
-            entries[0].binding = 0;
-            entries[0].buffer = src_ctx->buffer;
-            entries[0].offset = src_offset;
-            entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
-
-            entries[1].binding = 1;
-            entries[1].buffer = dst_ctx->buffer;
-            entries[1].offset = dst_offset;
-            entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
-
-            entries[2].binding = 2;
-            entries[2].buffer = ctx->cpy_params_dev_buf;
-            entries[2].offset = 0;
-            entries[2].size = ctx->cpy_params_dev_buf.GetSize();
-
-            wgpu::BindGroupDescriptor bind_group_desc;
-            bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0);
-            bind_group_desc.label = "ggml_op_cpy";
-            bind_group_desc.entryCount = 3;
-            bind_group_desc.entries = entries;
-            wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
-
-            wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-            encoder.CopyBufferToBuffer(
-                ctx->cpy_params_host_buf, 0,
-                ctx->cpy_params_dev_buf, 0,
-                ctx->cpy_params_dev_buf.GetSize()
-            );
-            wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
-            pass.SetPipeline(ctx->cpy_pipeline);
-            pass.SetBindGroup(0, bind_group);
-            size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
-            pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size);
-            pass.End();
-            wgpu::CommandBuffer commands = encoder.Finish();
-
-            // TODO, don't submit here, batch submissions
-            ctx->queue.Submit(1, &commands);
-            // TODO, don't wait on submission here
-            ggml_backend_webgpu_wait_on_submission(ctx);
-            return true;
-        }
-
+        case GGML_OP_CPY:
+            {
+                ggml_webgpu_cpy(ctx, src0, node);
+                break;
+            }
         case GGML_OP_MUL_MAT:
-         {
-            const ggml_tensor * src0 = node->src[0];
-            ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context;
-            size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs;
-            const ggml_tensor * src1 = node->src[1];
-            ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context;
-            size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs;
-            ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context;
-
-            size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs;
-
-            wgpu::Device device = ctx->device;
-
-            // map the host parameters buffer
-            ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf,
-                wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize());
-            uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange();
-
-            params[0] = (uint32_t)node->ne[1]; // number of rows in result (M)
-            params[1] = (uint32_t)node->ne[0]; // number of columns in result (N)
-            params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K)
-
-            params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1
-            params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1
-            params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2
-            params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2
-            params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3
-            params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3
-
-            params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2
-            params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3
-            params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2
-            params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3
-
-            ctx->mul_mat_params_host_buf.Unmap();
-
-            wgpu::BindGroupEntry entries[4];
-            entries[0].binding = 0;
-            entries[0].buffer = src0_ctx->buffer;
-            entries[0].offset = src0_offset;
-            entries[0].size = ggml_nbytes(src0);
-
-            entries[1].binding = 1;
-            entries[1].buffer = src1_ctx->buffer;
-            entries[1].offset = src1_offset;
-            entries[1].size = ggml_nbytes(src1);
-
-            entries[2].binding = 2;
-            entries[2].buffer = dst_ctx->buffer;
-            entries[2].offset = dst_offset;
-            entries[2].size = ggml_nbytes(node);
-
-            entries[3].binding = 3;
-            entries[3].buffer = ctx->mul_mat_params_dev_buf;
-            entries[3].offset = 0;
-            entries[3].size = ctx->mul_mat_params_dev_buf.GetSize();
-
-            wgpu::BindGroupDescriptor bind_group_desc;
-            bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0);
-            bind_group_desc.entryCount = 4;
-            bind_group_desc.label = "ggml_op_mul_mat";
-            bind_group_desc.entries = entries;
-            wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc);
-
-            wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-            encoder.CopyBufferToBuffer(
-                ctx->mul_mat_params_host_buf, 0,
-                ctx->mul_mat_params_dev_buf, 0,
-                ctx->mul_mat_params_dev_buf.GetSize()
-            );
-            wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
-            pass.SetPipeline(ctx->mul_mat_pipeline);
-            pass.SetBindGroup(0, bind_group);
-            pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE);
-            pass.End();
-            wgpu::CommandBuffer commands = encoder.Finish();
-
-            // TODO, don't submit here, batch submissions
-            ctx->queue.Submit(1, &commands);
-            // TODO, don't wait on submission here
-            ggml_backend_webgpu_wait_on_submission(ctx);
-            return true;
-        }
-
+            {
+                ggml_webgpu_mul_mat(ctx, src0, src1, node);
+                break;
+            }
         default:
             return false;
     }
+    return true;
 }
 
 static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
 
     ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
-    webgpu_context ctx = backend_ctx->webgpu_ctx;
+    webgpu_context                ctx         = backend_ctx->webgpu_ctx;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_webgpu_encode_node(ctx, cgraph->nodes[i]);
     }
 
+    ggml_backend_webgpu_submit_queue(ctx);
+    ggml_backend_webgpu_wait_on_submission(ctx);
+
     return GGML_STATUS_SUCCESS;
 }
 
@@ -465,49 +530,69 @@ static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer)
     return webgpu_ptr_base;
 }
 
-static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
+                                                     ggml_tensor *         tensor,
+                                                     uint8_t               value,
+                                                     size_t                offset,
+                                                     size_t                size) {
     if (size == 0) {
         WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
         return;
     }
 
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
+                                                                 << offset << ", " << size << ")");
 
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
-    uint32_t val32 = (uint32_t)value * 0x01010101;
+    uint32_t val32 = (uint32_t) value * 0x01010101;
     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
 }
 
-static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
-    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
+static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                                                  ggml_tensor *         tensor,
+                                                  const void *          data,
+                                                  size_t                offset,
+                                                  size_t                size) {
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
+                                                              << offset << ", " << size << ")");
+    ggml_backend_webgpu_buffer_context * buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
+    webgpu_context                       webgpu_ctx = buf_ctx->webgpu_ctx;
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
-    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4);
+    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
 
     if (size % 4 != 0) {
         // If size is not a multiple of 4, we need to memset the remaining bytes
         size_t remaining_size = size % 4;
+
         // pack the remaining bytes into a uint32_t
         uint32_t val32 = 0;
+
         for (size_t i = 0; i < remaining_size; i++) {
-            ((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i];
+            ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
         }
         // memset the remaining bytes
-        ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
+        ggml_backend_webgpu_buffer_memset(
+            webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
     }
 }
 
-static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
+static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
+                                                  const ggml_tensor *   tensor,
+                                                  void *                data,
+                                                  size_t                offset,
+                                                  size_t                size) {
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
+                                                              << offset << ", " << size << ")");
 
-    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
-    wgpu::Device device = webgpu_ctx->device;
+    ggml_backend_webgpu_buffer_context * buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
+    webgpu_context                       webgpu_ctx = buf_ctx->webgpu_ctx;
+    wgpu::Device                         device     = webgpu_ctx->device;
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
@@ -517,22 +602,25 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
         final_size = size + (4 - (size % 4));
     }
 
-    std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
+    std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex);
 
-    if (webgpu_ctx->get_tensor_staging_buf == nullptr ||
-        webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
+    if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
         // Create a new staging buffer if it doesn't exist or is too small
         if (webgpu_ctx->get_tensor_staging_buf) {
             webgpu_ctx->get_tensor_staging_buf.Destroy();
         }
-        ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
-            wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
+        ggml_webgpu_create_buffer(device,
+                                  webgpu_ctx->get_tensor_staging_buf,
+                                  final_size,
+                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
+                                  "get_tensor_staging_buf");
     }
 
     // Copy the data from the buffer to the staging buffer
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
     encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
     wgpu::CommandBuffer commands = encoder.Finish();
+
     // Submit the command buffer to the queue
     webgpu_ctx->queue.Submit(1, &commands);
 
@@ -548,7 +636,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
 
 static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
-
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
 }
@@ -556,13 +643,13 @@ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8
 static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_webgpu_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_webgpu_buffer_get_base,
-    /* .init_tensor     = */ NULL, // TODO: optional, needed?
+    /* .init_tensor     = */ NULL,  // TODO: optional, needed?
     /* .memset_tensor   = */ ggml_backend_webgpu_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_webgpu_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_webgpu_buffer_get_tensor,
-    /* .cpy_tensor      = */ NULL, // TODO: optional, implement this
+    /* .cpy_tensor      = */ NULL,  // TODO: optional, implement this
     /* .clear           = */ ggml_backend_webgpu_buffer_clear,
-    /* .reset           = */ NULL, // TODO: optional, think it coordinates with .init_tensor
+    /* .reset           = */ NULL,  // TODO: optional, think it coordinates with .init_tensor
 };
 
 /* End GGML Backend Buffer Interface */
@@ -574,13 +661,17 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
     return ctx->device_name.c_str();
 }
 
-static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+                                                                          size_t                     size) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
 
     wgpu::Buffer buf;
-    ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size,
-        wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer");
+    ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
+                              buf,
+                              size,
+                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
+                              "allocated_buffer");
 
     ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
 
@@ -615,8 +706,8 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_
 static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
     // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
-    *free = ctx->webgpu_ctx->limits.maxBufferSize;
-    *total = ctx->webgpu_ctx->limits.maxBufferSize;
+    *free                                    = ctx->webgpu_ctx->limits.maxBufferSize;
+    *total                                   = ctx->webgpu_ctx->limits.maxBufferSize;
 }
 
 static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
@@ -639,98 +730,93 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct
 
 static ggml_guid_t ggml_backend_webgpu_guid(void) {
     static const char * guid_str = "__ggml_webgpu :)";
-    return reinterpret_cast<ggml_guid_t>((void *)guid_str);
+    return reinterpret_cast<ggml_guid_t>((void *) guid_str);
 }
 
-static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
+static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) {
     // we use the maximum workgroup size for the memset pipeline
     size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
     size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
     // Size the bytes_per_thread so that the largest buffer size can be handled
-    webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
+    webgpu_ctx->memset_bytes_per_thread =
+        (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
     std::vector<wgpu::ConstantEntry> constants(2);
-    constants[0].key = "wg_size";
+    constants[0].key   = "wg_size";
     constants[0].value = max_wg_size;
-    constants[1].key = "bytes_per_thread";
+    constants[1].key   = "bytes_per_thread";
     constants[1].value = webgpu_ctx->memset_bytes_per_thread;
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf,
-        3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value
-        wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf");
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf,
-        3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf");
 }
 
-static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
+static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) {
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
-        wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf");
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE,
-        wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf");
 }
 
-static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) {
+static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key = "wg_size";
+    constants[0].key   = "wg_size";
     constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
-
     ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE,
-        wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf");
-    ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE,
-        wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf");
 }
 
-// TODO: Make thread safe if multiple devices are used
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
 
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
 
-    ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
-    webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
-
-    std::lock_guard<std::mutex> lock(webgpu_ctx->mutex);
+    ggml_backend_webgpu_device_context * dev_ctx    = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
+    webgpu_context                       webgpu_ctx = dev_ctx->webgpu_ctx;
 
-    if (!webgpu_ctx->device_initialized) {
+    // Multiple threads may try to initialize the device
+    std::lock_guard<std::mutex> lock(webgpu_ctx->init_mutex);
+    if (!webgpu_ctx->device_init) {
         // Initialize device
-        wgpu::DeviceDescriptor dev_desc;
-        dev_desc.requiredLimits = &webgpu_ctx->limits;
-        dev_desc.requiredFeatures = webgpu_ctx->features.features;
-        dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
-        dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous,
-            [](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
+        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization };
+        wgpu::DeviceDescriptor         dev_desc;
+        dev_desc.requiredLimits       = &webgpu_ctx->limits;
+        dev_desc.requiredFeatures     = required_features.data();
+        dev_desc.requiredFeatureCount = required_features.size();
+        dev_desc.SetDeviceLostCallback(
+            wgpu::CallbackMode::AllowSpontaneous,
+            [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
                 GGML_UNUSED(device);
-                GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
-        });
+                GGML_LOG_ERROR(
+                    "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
+            });
         dev_desc.SetUncapturedErrorCallback(
-            [](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) {
+            [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
                 GGML_UNUSED(device);
-                GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
-        });
-        webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly,
-            [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
-                if (status != wgpu::RequestDeviceStatus::Success) {
-                    GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
-                    return;
-                }
-                webgpu_ctx->device = device;
-            }),
-            UINT64_MAX
-        );
+                GGML_LOG_ERROR(
+                    "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
+            });
+        webgpu_ctx->instance.WaitAny(
+            webgpu_ctx->adapter.RequestDevice(
+                &dev_desc,
+                wgpu::CallbackMode::AllowSpontaneous,
+                [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
+                    if (status != wgpu::RequestDeviceStatus::Success) {
+                        GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
+                        return;
+                    }
+                    webgpu_ctx->device = std::move(device);
+                }),
+            UINT64_MAX);
         GGML_ASSERT(webgpu_ctx->device != nullptr);
 
         // Initialize (compute) queue
         webgpu_ctx->queue = webgpu_ctx->device.GetQueue();
 
+        // Create buffer pool for shader parameters
+        webgpu_ctx->param_buf_pool.init(webgpu_ctx->device);
+
         ggml_webgpu_init_memset_pipeline(webgpu_ctx);
         ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
         ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
-        webgpu_ctx->device_initialized = true;
+        webgpu_ctx->device_init = true;
     }
 
     static ggml_backend_webgpu_context backend_ctx;
-    backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
+    backend_ctx.name       = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
     backend_ctx.webgpu_ctx = webgpu_ctx;
 
     // See GGML Backend Interface section
@@ -748,14 +834,15 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
     // See GGML Backend Buffer Type Interface section
     static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
         /* .iface = */ {
-            /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,
-            /* .alloc_buffer     = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_webgpu_buffer_type_get_alignment,
-            /* .get_max_size     = */ ggml_backend_webgpu_buffer_type_get_max_size,
-            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-            /* .is_host          = */ NULL, // defaults to false
+                        /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,
+                        /* .alloc_buffer     = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
+                        /* .get_alignment    = */ ggml_backend_webgpu_buffer_type_get_alignment,
+                        /* .get_max_size     = */ ggml_backend_webgpu_buffer_type_get_max_size,
+                        /* .get_alloc_size   = */ NULL,  // defaults to ggml_nbytes
+            /* .is_host          = */ NULL,  // defaults to false
         },
-        /* .device  = */ dev,
+        /* .device  = */
+        dev,
         /* .context = */ NULL,
     };
 
@@ -764,7 +851,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
 
 static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
     GGML_UNUSED(dev);
-    return  buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
+    return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
 }
 
 static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
@@ -827,30 +914,38 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     webgpu_context ctx = reg_ctx->webgpu_ctx;
 
     wgpu::RequestAdapterOptions options = {};
-    auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) {
-        if (status != wgpu::RequestAdapterStatus::Success) {
-            GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
-            return;
-        }
-        *static_cast<wgpu::Adapter *>(userdata) = adapter;
-    };
-    void *userdata = &ctx->adapter;
-    ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX);
+    auto                        callback =
+        [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
+            if (status != wgpu::RequestAdapterStatus::Success) {
+                GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
+                return;
+            }
+            *static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
+        };
+    void * userdata = &ctx->adapter;
+    ctx->instance.WaitAny(
+        ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
     GGML_ASSERT(ctx->adapter != nullptr);
 
     ctx->adapter.GetLimits(&ctx->limits);
-    ctx->adapter.GetFeatures(&ctx->features);
 
     wgpu::AdapterInfo info{};
     ctx->adapter.GetInfo(&info);
 
     static ggml_backend_webgpu_device_context device_ctx;
-    device_ctx.webgpu_ctx = ctx;
+    device_ctx.webgpu_ctx  = ctx;
     device_ctx.device_name = GGML_WEBGPU_NAME;
     device_ctx.device_desc = std::string(info.description.data);
 
-    GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n",
-        info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data);
+    GGML_LOG_INFO(
+        "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
+        "device_desc: %s\n",
+        info.vendorID,
+        info.vendor.data,
+        info.architecture.data,
+        info.deviceID,
+        info.device.data,
+        info.description.data);
 
     // See GGML Backend Device Interface section
     static ggml_backend_device device = {
@@ -861,7 +956,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     return &device;
 }
 
-
 static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
     /* .get_name         = */ ggml_backend_webgpu_reg_get_name,
     /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
@@ -871,23 +965,21 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
 
 /* End GGML Backend Registration Interface */
 
-// TODO: Does this need to be thread safe? Is it only called once?
 ggml_backend_reg_t ggml_backend_webgpu_reg() {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
 
     webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
-    webgpu_ctx->device_initialized = false;
 
     static ggml_backend_webgpu_reg_context ctx;
-    ctx.webgpu_ctx = webgpu_ctx;
-    ctx.name = GGML_WEBGPU_NAME;
+    ctx.webgpu_ctx   = webgpu_ctx;
+    ctx.name         = GGML_WEBGPU_NAME;
     ctx.device_count = 1;
 
-    wgpu::InstanceDescriptor instance_descriptor{};
-    std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny};
-    instance_descriptor.requiredFeatures = instance_features.data();
-    instance_descriptor.requiredFeatureCount = instance_features.size();
-    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
+    wgpu::InstanceDescriptor               instance_descriptor{};
+    std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
+    instance_descriptor.requiredFeatures                     = instance_features.data();
+    instance_descriptor.requiredFeatureCount                 = instance_features.size();
+    webgpu_ctx->instance                                     = wgpu::CreateInstance(&instance_descriptor);
     GGML_ASSERT(webgpu_ctx->instance != nullptr);
 
     static ggml_backend_reg reg = {