]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml webgpu: Split shared state (webgpu_context) into global state and per-thread...
authorNikhil Jain <redacted>
Wed, 28 Jan 2026 04:53:36 +0000 (20:53 -0800)
committerGitHub <redacted>
Wed, 28 Jan 2026 04:53:36 +0000 (20:53 -0800)
* Squashed commit of the following:

commit b3c6bf4b0450d8d452b934df27a0fb7cb53cd755
Author: Abhijit Ramesh <redacted>
Date:   Mon Dec 1 18:29:00 2025 -0800

    ggml webgpu: fix xielu parameter passing (#11)

    The XIELU operation was incorrectly using static_cast to convert
    float parameters to uint32_t, which converted numeric values instead
    of preserving IEEE 754 bit patterns. This caused incorrect values
    to be interpreted by the GPU shader.

    * Use reinterpret_cast to preserve float bit patterns when passing
      through uint32_t params buffer
    * Update WGSL shader parameter types from u32 to f32
    * Re-enable XIELU support (was disabled due to numerical issues)

    Fixes NMSE test failures for XIELU operation on WebGPU backend.

commit 5ca9b5e49ea7cddc9ab7c8b43a11a9c76a4dff4a
Author: neha-ha <redacted>
Date:   Tue Nov 18 12:17:00 2025 -0800

    Refactored pipelines and workgroup calculations (#10)

    * refactored pipelines

    * refactored workgroup calculation

    * removed commented out block of prior maps

    * Clean up ceiling division pattern

    ---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
Author: James Contini <redacted>
Date:   Wed Oct 29 23:13:06 2025 -0700

    formatted embed wgsl and ggml-webgpu.cpp

commit e1f6baea31645e5d96ad53664acae856f74b96f4
Author: James Contini <redacted>
Date:   Wed Oct 29 23:08:37 2025 -0700

    implemented REPL_Template support and removed bug in unary operators kernel

commit 8c70b8fece445cdc9a8c660dbddbf201e52da2bb
Author: James Contini <redacted>
Date:   Wed Oct 15 16:14:20 2025 -0700

    responded and dealt with PR comments

commit f9282c660c10dec4487d434549bdb707a9cd9f37
Author: James Contini <redacted>
Date:   Sun Oct 12 13:41:41 2025 -0700

    removed unnecesarry checking if node->src[1] exists for unary operators

commit 4cf28d7dec41c29186d66152735b244c5699f9dc
Author: James Contini <redacted>
Date:   Sun Oct 12 13:32:45 2025 -0700

    All operators (inlcluding xielu) working

commit 74c6add1761a59d2c2ff60b60e8ad3c8300f6d3e
Author: James Contini <redacted>
Date:   Fri Oct 10 13:16:48 2025 -0700

    fixed autoconfig

commit 362749910be4f0120c8ffb21ceddeb7d2c088e51
Author: James Contini <redacted>
Date:   Fri Oct 10 13:10:46 2025 -0700

    removed vestigial files

commit cb0858333785757804c5104e59c4981843207c16
Author: James Contini <redacted>
Date:   Fri Oct 10 12:59:32 2025 -0700

    abides by editor-config

commit 5360e2852a4b51197d7d67d0a5d42e908b02d7ed
Author: James Contini <redacted>
Date:   Fri Oct 10 12:45:57 2025 -0700

    rms_norm double declaration bug atoned

commit 7b09baa4aa53711be5a126043670cc182c78bfcd
Merge: 8a6ec843 74b8fc17
Author: James Contini <redacted>
Date:   Fri Oct 10 11:50:03 2025 -0700

    resolving merge conflicts

commit 8a6ec843a50ab82f8cef59b4558eb63f318ba02d
Author: James Contini <redacted>
Date:   Wed Oct 8 18:06:47 2025 -0700

    unary operators pass ggml tests

commit c3ae38278a2db236adc5912c9140e4f0d63f2c19
Author: James Contini <redacted>
Date:   Wed Oct 1 16:22:40 2025 -0700

    neg passes backend test

commit aa1c9b2f8877a405470ca56709c42a1fd43713de
Author: James Contini <redacted>
Date:   Tue Sep 30 23:55:27 2025 -0700

    neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though

Co-authored-by: James Contini <redacted>
Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Abhijit Ramesh <redacted>
* Remove extra code and format

* Add ops documentation (finally)

* ggml webgpu: add SOFTPLUS unary operator

Implements SOFTPLUS (log(1 + exp(x))) with f16/f32 support. Uses f32
precision for intermediate calculations to prevent f16 overflow.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support
* Follow Vulkan backend numerical stability pattern

* ggml webgpu: add EXPM1 unary operator

Implements EXPM1 (exp(x) - 1) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add FLOOR unary operator

Implements FLOOR (rounds down to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add CEIL unary operator

Implements CEIL (rounds up to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add ROUND unary operator

Implements ROUND (rounds to nearest integer) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* ggml webgpu: add TRUNC unary operator

Implements TRUNC (truncates towards zero) with f16/f32 support.

* Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
* Register pipelines and device support

* docs : update WebGPU support for unary operators (FLOOR, CEIL, ROUND, TRUNC, EXPM1, SOFTPLUS)

* Updates to webgpu get_memory

* Move shared state (webgpu_context) and device creation out of registration context, device context, and buffer context, and move into backend context

* Small cleanup

* Move Instance, Device, Adapter, Device creation, and capabilities to global state while moving Queue, pipelines, and buffers to per-thread state.

* Cleanups

* More cleanup

* Move staging_buf mutex to global context

* Resolve merge

* Resolve merge

* Resolve merge

* Clean up merge errors, delete forward declaration, and run clang-format

* Rename device_init to backend_init

* Move webgpu_context to backend_context

* Move buffer context members into global context and refactor function calls

* Run clang-format

* Remove commends

* Move parameter buffers to per-thread, add single memset_tensor param buf

* Fix CI compilation issue

* Fix builds for emscripten not supporting subgroups

* cleanup

* cleanup

---------

Co-authored-by: Reese Levine <redacted>
ggml/src/ggml-webgpu/ggml-webgpu.cpp

index 584cea7698b8a011f814e889e794b761cb6e8d9b..22e2bfeb4ce498f8666eadcab350e0541e1ba3b5 100644 (file)
@@ -47,7 +47,6 @@
         double cpu_total_time_##id =                                                                      \
             std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
         (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
-
 // fine-grained timing (not included in totals)
 #    define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
 
 #define WEBGPU_MAX_WG_SIZE 288
 
 #define WEBGPU_MUL_MAT_WG_SIZE               256
-#define WEBGPU_NUM_PARAM_BUFS                32u
+#define WEBGPU_NUM_PARAM_BUFS                16u
 #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     8u
 #define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
 // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
 #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
-#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       32
+#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       16
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
 #define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
 
@@ -267,30 +266,67 @@ struct webgpu_command {
 #endif
 };
 
-// All the base objects needed to run operations on a WebGPU device
-struct webgpu_context_struct {
+struct webgpu_capabilities_base {
+    wgpu::Limits limits;
+    bool         supports_subgroup_matrix = false;
+
+    uint32_t sg_mat_m = 0;
+    uint32_t sg_mat_n = 0;
+    uint32_t sg_mat_k = 0;
+
+    uint32_t subgroup_size     = 0;
+    uint32_t max_subgroup_size = 0;
+    size_t   memset_bytes_per_thread;
+};
+
+// Stores global webgpu members
+struct webgpu_global_context_struct {
     wgpu::Instance instance;
     wgpu::Adapter  adapter;
     wgpu::Device   device;
     wgpu::Queue    queue;
-    wgpu::Limits   limits;
 
-    uint32_t max_subgroup_size;
+    webgpu_capabilities_base capabilities;
+    // Shared buffer to move data from device to host
+    wgpu::Buffer             get_tensor_staging_buf;
+    // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
+    std::recursive_mutex     mutex;
 
-    bool     supports_subgroup_matrix = false;
-    uint32_t sg_mat_m;
-    uint32_t sg_mat_n;
-    uint32_t sg_mat_k;
+    webgpu_buf_pool                memset_buf_pool;
+    std::map<int, webgpu_pipeline> memset_pipelines;  // variant or type index
+    std::atomic_uint               inflight_threads = 0;
 
-    std::recursive_mutex mutex;
-    std::atomic_uint     inflight_threads = 0;
+#ifdef GGML_WEBGPU_CPU_PROFILE
+    // Profiling: labeled CPU time in ms (total)
+    std::unordered_map<std::string, double> cpu_time_ms;
+    // Profiling: detailed CPU time in ms
+    std::unordered_map<std::string, double> cpu_detail_ms;
+#endif
 
-    webgpu_buf_pool param_buf_pool;
-    webgpu_buf_pool set_rows_error_buf_pool;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    // Profiling: per-shader GPU time in ms
+    std::unordered_map<std::string, double> shader_gpu_time_ms;
+    // Profiling: pool of timestamp query buffers (one per operation)
+    webgpu_gpu_profile_buf_pool             timestamp_query_buf_pool;
+#endif
+
+#ifdef GGML_WEBGPU_DEBUG
+    wgpu::Buffer debug_host_buf;
+    wgpu::Buffer debug_dev_buf;
+#endif
+};
+
+typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
+
+// All the base objects needed to run operations on a WebGPU device
+struct webgpu_context_struct {
+    // Points to global instances owned by ggml_backend_webgpu_reg_context
+    webgpu_global_context global_ctx;
 
     pre_wgsl::Preprocessor p;
 
-    std::map<int, webgpu_pipeline> memset_pipelines;                                 // variant or type index
+    webgpu_buf_pool param_buf_pool;
+    webgpu_buf_pool set_rows_error_buf_pool;
 
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines;  // src0_type, src1_type, vectorized
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
@@ -326,57 +362,42 @@ struct webgpu_context_struct {
 
     size_t memset_bytes_per_thread;
 
-    // Staging buffer for reading data from the GPU
-    wgpu::Buffer get_tensor_staging_buf;
-
-#ifdef GGML_WEBGPU_DEBUG
-    wgpu::Buffer debug_host_buf;
-    wgpu::Buffer debug_dev_buf;
-#endif
-
-#ifdef GGML_WEBGPU_CPU_PROFILE
-    // Profiling: labeled CPU time in ms (total)
-    std::unordered_map<std::string, double> cpu_time_ms;
-    // Profiling: detailed CPU time in ms
-    std::unordered_map<std::string, double> cpu_detail_ms;
-#endif
-
-#ifdef GGML_WEBGPU_GPU_PROFILE
-    // Profiling: per-shader GPU time in ms
-    std::unordered_map<std::string, double> shader_gpu_time_ms;
-    // Profiling: pool of timestamp query buffers (one per operation)
-    webgpu_gpu_profile_buf_pool             timestamp_query_buf_pool;
-#endif
 };
 
 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
 
+// Metadata required for the ggml backend registration/discovery interface
 struct ggml_backend_webgpu_reg_context {
-    webgpu_context webgpu_ctx;
-    size_t         device_count;
-    const char *   name;
+    // Since the Instance is a global entrypoint into the WebGPU API, it lives here
+    webgpu_global_context webgpu_global_ctx;
+    size_t                device_count;
+    const char *          name;
 };
 
+// Per-device struct for the global logical device interface
 struct ggml_backend_webgpu_device_context {
-    webgpu_context webgpu_ctx;
-    std::string    device_name;
-    std::string    device_desc;
+    webgpu_global_context webgpu_global_ctx;
+    std::string           device_name;
+    std::string           device_desc;
 };
 
+// Per-thread data required to actually run WebGPU operations in a backend instance
 struct ggml_backend_webgpu_context {
-    webgpu_context webgpu_ctx;
-    std::string    name;
+    webgpu_context        webgpu_ctx;
+    std::once_flag        init_once;
+    std::string           name;
 };
 
+// Per-thread data related to buffers
 struct ggml_backend_webgpu_buffer_context {
-    webgpu_context webgpu_ctx;
-    wgpu::Buffer   buffer;
-    std::string    label;
+    wgpu::Buffer          buffer;
+    std::string           label;
+    webgpu_global_context global_ctx;
 
-    ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
-        webgpu_ctx(std::move(ctx)),
+    ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
         buffer(std::move(buf)),
-        label(std::move(lbl)) {}
+        label(std::move(lbl)),
+        global_ctx(std::move(global_ctx_)) {}
 };
 
 /* WebGPU object initializations */
@@ -444,7 +465,7 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 /** WebGPU Actions */
 
 // Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_context &                         ctx,
+static void ggml_backend_webgpu_wait(webgpu_global_context &                  ctx,
                                      std::vector<webgpu_submission_futures> & futures,
                                      bool                                     block = true) {
     // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
@@ -476,11 +497,11 @@ static void ggml_backend_webgpu_wait(webgpu_context &                         ct
     }
 }
 
-static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
-                                           wgpu::Buffer &   buffer,
-                                           wgpu::MapMode    mode,
-                                           size_t           offset,
-                                           size_t           size) {
+static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
+                                           wgpu::Buffer &          buffer,
+                                           wgpu::MapMode           mode,
+                                           size_t                  offset,
+                                           size_t                  size) {
     ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
                                           [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                                               if (status != wgpu::MapAsyncStatus::Success) {
@@ -495,7 +516,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
 // This function adds debugging information to shaders, as WebGPU does not support printing directly.
 // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
 // debug statements in the shader, and then call this function after encoding the commands and submitting them.
-static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
+static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
     wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
     encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
     wgpu::CommandBuffer commands = encoder.Finish();
@@ -507,7 +528,10 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
 }
 #endif
 
-static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
+static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context       ctx,
+                                                            std::vector<webgpu_command> commands,
+                                                            webgpu_buf_pool &           param_buf_pool,
+                                                            webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
     std::vector<wgpu::CommandBuffer> command_buffers;
     std::vector<webgpu_pool_bufs>    params_bufs;
     std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
@@ -528,19 +552,19 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx,
 
     wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
-        [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+        [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
             if (status != wgpu::QueueWorkDoneStatus::Success) {
                 GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
             }
             // Free the staged buffers
-            ctx->param_buf_pool.free_bufs(params_bufs);
+            param_buf_pool.free_bufs(params_bufs);
         });
     futures.push_back({ p_f });
 
     for (const auto & bufs : set_rows_error_bufs) {
         wgpu::Future f = bufs.host_buf.MapAsync(
             wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
-            [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
+            [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                 if (status != wgpu::MapAsyncStatus::Success) {
                     GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
                 } else {
@@ -549,7 +573,9 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx,
                         GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
                     }
                     // We can't unmap in here due to WebGPU reentrancy limitations.
-                    ctx->set_rows_error_buf_pool.free_bufs({ bufs });
+                    if (set_rows_error_buf_pool) {
+                        set_rows_error_buf_pool->free_bufs({ bufs });
+                    }
                 }
             });
         futures.push_back({ f });
@@ -581,7 +607,8 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx,
 }
 
 static webgpu_command ggml_backend_webgpu_build_multi(
-    webgpu_context &                                       ctx,
+    webgpu_global_context &                                ctx,
+    webgpu_buf_pool &                                      param_buf_pool,
     const std::vector<webgpu_pipeline> &                   pipelines,
     const std::vector<std::vector<uint32_t>> &             params_list,
     const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
@@ -595,7 +622,7 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     std::vector<wgpu::BindGroup>  bind_groups;
 
     for (size_t i = 0; i < pipelines.size(); i++) {
-        webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+        webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
 
         ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
                                        params_bufs.host_buf.GetSize());
@@ -672,34 +699,37 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     return result;
 }
 
-static webgpu_command ggml_backend_webgpu_build(webgpu_context &                  ctx,
+static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &           ctx,
+                                                webgpu_buf_pool &                 param_buf_pool,
                                                 webgpu_pipeline &                 pipeline,
                                                 std::vector<uint32_t>             params,
                                                 std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                 uint32_t                          wg_x,
                                                 uint32_t                          wg_y                = 1,
                                                 std::optional<webgpu_pool_bufs>   set_rows_error_bufs = std::nullopt) {
-    return ggml_backend_webgpu_build_multi(ctx,
+    return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
                                            {
                                                pipeline
     },
                                            { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
 }
 
-static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
-                                              wgpu::Buffer &   buf,
-                                              uint32_t         value,
-                                              size_t           offset,
-                                              size_t           size) {
+static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
+                                              wgpu::Buffer &          buf,
+                                              uint32_t                value,
+                                              size_t                  offset,
+                                              size_t                  size) {
     std::vector<uint32_t>             params  = { (uint32_t) offset, (uint32_t) size, value };
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
     };
-    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread;
+    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
     uint32_t wg_x         = CEIL_DIV(size + 3, bytes_per_wg);
 
-    webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x);
-    std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
+    webgpu_command command =
+        ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
+    std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
+                                                                                  ctx->memset_buf_pool) };
     ggml_backend_webgpu_wait(ctx, futures);
 }
 
@@ -720,19 +750,19 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
 #ifdef GGML_WEBGPU_CPU_PROFILE
     std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
     double total_cpu = 0.0;
-    for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
         total_cpu += kv.second;
     }
     std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
     std::cout << "ggml_webgpu: cpu breakdown:\n";
-    for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
         double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
         std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
     }
-    if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) {
+    if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
         std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
     }
-    for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
         double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
         std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
     }
@@ -741,12 +771,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
 #ifdef GGML_WEBGPU_GPU_PROFILE
     std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
     double total_gpu = 0.0;
-    for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
         total_gpu += kv.second;
     }
     std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
     std::cout << "\nggml_webgpu: gpu breakdown:\n";
-    for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
         double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
         std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
     }
@@ -772,12 +802,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
 
 static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
-    return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 }
 
 static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
-    return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 }
 
 static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
@@ -818,28 +848,30 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
     };
 
     uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
+                                     params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
 
     ggml_webgpu_pad_pipeline_key       pipeline_key   = { .circular = circular };
-    ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { .key = pipeline_key,
-                                                          .max_wg_size =
-                                                              ctx->limits.maxComputeInvocationsPerWorkgroup };
+    ggml_webgpu_pad_shader_lib_context shader_lib_ctx = {
+        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
 
     webgpu_pipeline pipeline;
     {
         // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
         auto                                  it = ctx->pad_pipelines.find(pipeline_key);
         if (it != ctx->pad_pipelines.end()) {
             pipeline = it->second;
         } else {
             ggml_webgpu_processed_shader processed =
                 ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline =
+                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
             pipeline.context = processed.decisions;
             ctx->pad_pipelines.emplace(pipeline_key, pipeline);
         }
@@ -891,7 +923,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
     };
 
     uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
@@ -907,21 +939,22 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
                                               .vec4     = src->ne[0] % 4 == 0,
                                               .i64_idx  = idx->type == GGML_TYPE_I64 };
 
-    ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { .key = key,
-                                                               .max_wg_size =
-                                                                   ctx->limits.maxComputeInvocationsPerWorkgroup };
+    ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = {
+        .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
 
     webgpu_pipeline pipeline;
     // TODO: remove guard once pipeline caches are per-thread
     {
-        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
         auto                                  it = ctx->set_rows_pipelines.find(key);
         if (it != ctx->set_rows_pipelines.end()) {
             pipeline = it->second;
         } else {
             ggml_webgpu_processed_shader processed =
                 ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline =
+                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
             pipeline.context = processed.decisions;
             ctx->set_rows_pipelines.emplace(key, pipeline);
         }
@@ -981,7 +1014,8 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
         threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
     }
     uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
+                                     error_bufs);
 }
 
 static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
@@ -1023,7 +1057,7 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
 
     uint32_t        vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
     webgpu_pipeline pipeline   = ctx->get_rows_pipelines[src->type][vectorized];
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
@@ -1098,19 +1132,21 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
             uint32_t batches       = dst->ne[2] * dst->ne[3];
             uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
             uint32_t total_wg      = output_groups * batches;
-            wg_x                   = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
-            wg_y                   = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension);
+            wg_x                   = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+            wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
         } else {
             pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
             uint32_t wg_m;
             uint32_t wg_n;
 #ifndef __EMSCRIPTEN__
-            if (ctx->supports_subgroup_matrix) {
+            if (ctx->global_ctx->capabilities.supports_subgroup_matrix) {
                 // The total number of subgroups/workgroups needed per matrix.
-                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
+                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M *
+                                        ctx->global_ctx->capabilities.sg_mat_m;
                 wg_m                  = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
-                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
-                wg_n                  = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
+                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N *
+                                        ctx->global_ctx->capabilities.sg_mat_n;
+                wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
             } else {
 #endif
                 uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
@@ -1124,9 +1160,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
             wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
         }
     }
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
 }
 
+#ifndef __EMSCRIPTEN__
 static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                              ggml_tensor *    Q,
                                              ggml_tensor *    K,
@@ -1210,8 +1247,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
-    bool kv_direct =
-        (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+    bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
+                     (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
 
     ggml_webgpu_flash_attn_pipeline_key key = {
         .kv_type            = K->type,
@@ -1223,25 +1260,27 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
         .uses_logit_softcap = logit_softcap != 0.0f,
     };
 
-    webgpu_pipeline                         pipeline;
+    webgpu_pipeline pipeline;
     // TODO: remove guard once pipeline caches are per-thread
     {
-        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
         auto                                  it = ctx->flash_attn_pipelines.find(key);
         if (it != ctx->flash_attn_pipelines.end()) {
-            pipeline  = it->second;
+            pipeline = it->second;
         } else {
-            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .key      = key,
-                                                                         .sg_mat_m = ctx->sg_mat_m,
-                                                                         .sg_mat_n = ctx->sg_mat_n,
-                                                                         .sg_mat_k = ctx->sg_mat_k,
-                                                                         .wg_mem_limit_bytes =
-                                                                             ctx->limits.maxComputeWorkgroupStorageSize,
-                                                                         .max_subgroup_size = ctx->max_subgroup_size };
+            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
+                .key                = key,
+                .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
+                .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
+                .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
+                .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+                .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size
+            };
 
             ggml_webgpu_processed_shader processed =
                 ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline =
+                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
             pipeline.context = processed.decisions;
             ctx->flash_attn_pipelines.emplace(key, pipeline);
         }
@@ -1250,11 +1289,11 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
     ggml_webgpu_flash_attn_shader_decisions decisions =
         *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
 
-
     uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
     uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
+#endif
 
 static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
@@ -1264,21 +1303,22 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
     ggml_webgpu_unary_pipeline_key pipeline_key = {
         .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
     };
-    ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { .key = pipeline_key,
-                                                            .max_wg_size =
-                                                                ctx->limits.maxComputeInvocationsPerWorkgroup };
+    ggml_webgpu_unary_shader_lib_context shader_lib_ctx = {
+        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
 
     webgpu_pipeline pipeline;
     {
         // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
         auto                                  it = ctx->unary_pipelines.find(pipeline_key);
         if (it != ctx->unary_pipelines.end()) {
             pipeline = it->second;
         } else {
             ggml_webgpu_processed_shader processed =
                 ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline =
+                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
             pipeline.context = processed.decisions;
             ctx->unary_pipelines.emplace(pipeline_key, pipeline);
         }
@@ -1346,7 +1386,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
     }
 
     uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
@@ -1391,7 +1431,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
     }
 
     uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -1426,7 +1466,8 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src));
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
+                                     entries, ggml_nrows(src));
 }
 
 static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@@ -1513,7 +1554,7 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
 
     webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
     uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -1565,7 +1606,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
 
     webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
     uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -1602,7 +1643,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
     }
 
     uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params,
+                                     entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -1674,7 +1716,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
+                                     ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
                                      ggml_nrows(dst));
 }
 
@@ -1696,25 +1739,26 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
 
     ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
         .vec4        = src->ne[0] % 4 == 0,
-        .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
 
     webgpu_pipeline pipeline;
     {
         // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
         auto                                  it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
         if (it != ctx->argmax_pipelines.end()) {
             pipeline = it->second;
         } else {
             ggml_webgpu_processed_shader processed =
                 ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline =
+                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
             ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
         }
     }
     uint32_t wg_x = ggml_nelements(dst);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -1722,13 +1766,13 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
     // ascending order is 0, descending order is 1
     const int32_t order    = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
 
-    ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { .max_wg_size =
-                                                                  ctx->limits.maxComputeInvocationsPerWorkgroup,
-                                                              .wg_mem_limit_bytes =
-                                                                  ctx->limits.maxComputeWorkgroupStorageSize,
-                                                              .order = order };
+    ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = {
+        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+        .order              = order
+    };
 
-    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+    std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
     webgpu_pipeline                       argsort_pipeline;
     auto                                  it = ctx->argsort_pipelines.find(order);
     if (it != ctx->argsort_pipelines.end()) {
@@ -1736,7 +1780,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
     } else {
         ggml_webgpu_processed_shader processed =
             ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
-        argsort_pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        argsort_pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
         argsort_pipeline.context = processed.decisions;
         ctx->argsort_pipelines.emplace(order, argsort_pipeline);
     }
@@ -1751,7 +1796,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
         ggml_webgpu_processed_shader processed =
             ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
         argsort_merge_pipeline =
-            ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
         argsort_merge_pipeline.context = processed.decisions;
         ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
     }
@@ -1780,9 +1825,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
 
     const bool start_in_tmp = (merge_passes % 2) == 1;
 
-    const size_t dst_offset       = ggml_webgpu_tensor_offset(dst);
-    const size_t idx_nbytes       = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
-    const size_t tmp_offset       = ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->limits.minStorageBufferOffsetAlignment);
+    const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
+    const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
+    const size_t tmp_offset =
+        ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
     const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
     const size_t dst_binding_size =
         ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
@@ -1813,10 +1859,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
     };
 
     const uint32_t                    total_wg_init = npr * nrows;
-    const uint32_t                    max_wg        = ctx->limits.maxComputeWorkgroupsPerDimension;
-    const uint32_t                    wg_x_init     = std::min(total_wg_init, max_wg);
-    const uint32_t                    wg_y_init     = CEIL_DIV(total_wg_init, wg_x_init);
-    std::vector<wgpu::BindGroupEntry> init_entries  = {
+    const uint32_t                    max_wg    = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+    const uint32_t                    wg_x_init = std::min(total_wg_init, max_wg);
+    const uint32_t                    wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
+    std::vector<wgpu::BindGroupEntry> init_entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
@@ -1830,7 +1876,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
     workgroups_list.push_back({ wg_x_init, wg_y_init });
 
     if (merge_passes == 0) {
-        return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list);
+        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+                                               entries_list, workgroups_list);
     }
 
     bool     in_is_tmp = start_in_tmp;
@@ -1891,7 +1938,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
         in_is_tmp = !in_is_tmp;
     }
 
-    return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list);
+    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
+                                           workgroups_list);
 }
 
 static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -1912,24 +1960,25 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
 
     ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
         .vec4        = false,
-        .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
     webgpu_pipeline pipeline;
     // TODO: remove guard once pipeline caches are per-thread
     {
-        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
         auto                                  it = ctx->cumsum_pipelines.find(1);
         if (it != ctx->cumsum_pipelines.end()) {
             pipeline = it->second;
         } else {
             ggml_webgpu_processed_shader processed =
                 ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline =
+                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
             ctx->cumsum_pipelines.emplace(1, pipeline);
         }
     }
     uint32_t wg_x = ggml_nrows(dst);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -1956,25 +2005,26 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
 
     ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
         .vec4        = false,
-        .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
 
     webgpu_pipeline pipeline;
     {
         // TODO: remove guard once pipeline caches are per-thread
-        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
         auto                                  it = ctx->sum_rows_pipelines.find(1);
         if (it != ctx->sum_rows_pipelines.end()) {
             pipeline = it->second;
         } else {
             ggml_webgpu_processed_shader processed =
                 ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline =
+                ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
             ctx->sum_rows_pipelines.emplace(1, pipeline);
         }
     }
     uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 // Returns the encoded command, or std::nullopt if the operation is a no-op
@@ -2009,7 +2059,11 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_MUL_MAT:
             return ggml_webgpu_mul_mat(ctx, src0, src1, node);
         case GGML_OP_FLASH_ATTN_EXT:
+#ifndef __EMSCRIPTEN__
             return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
+#else
+            return std::nullopt;
+#endif
         case GGML_OP_ADD:
             {
                 int inplace = ggml_webgpu_tensor_equal(src0, node);
@@ -2070,12 +2124,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
 static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
 
-    ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
+    ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
     webgpu_context                ctx         = backend_ctx->webgpu_ctx;
 
     WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
 
-    ctx->inflight_threads++;
+    ctx->global_ctx->inflight_threads++;
 
     std::vector<webgpu_command>            commands;
     std::vector<webgpu_submission_futures> futures;
@@ -2084,25 +2138,27 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
             commands.push_back(*cmd);
         }
         // compute the batch size based on the number of inflight threads
-        uint32_t inflight_threads = ctx->inflight_threads;
+        uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
         uint32_t batch_size       = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
                                              WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
         if (commands.size() >= batch_size) {
-            futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
+            futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
+                                                         &ctx->set_rows_error_buf_pool));
             // Process events and check for completed submissions
-            ctx->instance.ProcessEvents();
-            ggml_backend_webgpu_wait(ctx, futures, false);
+            ctx->global_ctx->instance.ProcessEvents();
+            ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
             commands.clear();
         }
     }
     if (!commands.empty()) {
-        webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
+        webgpu_submission_futures new_futures =
+            ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
         futures.push_back(new_futures);
     }
 
-    ggml_backend_webgpu_wait(ctx, futures);
-    ctx->inflight_threads--;
-    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
+    ggml_backend_webgpu_wait(ctx->global_ctx, futures);
+    ctx->global_ctx->inflight_threads--;
+    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
     return GGML_STATUS_SUCCESS;
 }
 
@@ -2159,8 +2215,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
 
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
     uint32_t val32 = (uint32_t) value * 0x01010101;
-    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
-    WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx);
+    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
+    WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
 }
 
 static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
@@ -2169,15 +2225,14 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                   size_t                offset,
                                                   size_t                size) {
     WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
-    ggml_backend_webgpu_buffer_context * buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    webgpu_context                       webgpu_ctx = buf_ctx->webgpu_ctx;
+    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
 
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                               << ", " << offset << ", " << size << ")");
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
-    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
+    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
 
     if (size % 4 != 0) {
         // If size is not a multiple of 4, we need to memset the remaining bytes
@@ -2190,21 +2245,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
             ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
         }
         // memset the remaining bytes
-        ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
-                                          remaining_size);
+        ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
+                                          total_offset + (size - remaining_size), remaining_size);
     } else {
         // wait for WriteBuffer to complete
-        webgpu_ctx->instance.WaitAny(
-            webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
+        buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
+                                                  wgpu::CallbackMode::AllowSpontaneous,
                                                   [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
                                                       if (status != wgpu::QueueWorkDoneStatus::Success) {
                                                           GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
                                                                          std::string(message).c_str());
                                                       }
                                                   }),
-            UINT64_MAX);
+                                              UINT64_MAX);
     }
-    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx);
+    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
 }
 
 static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
@@ -2216,8 +2271,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                               << ", " << offset << ", " << size << ")");
-    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
-    wgpu::Device   device     = webgpu_ctx->device;
+    wgpu::Device device = buf_ctx->global_ctx->device;
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
@@ -2227,42 +2281,45 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
         final_size = size + (4 - (size % 4));
     }
 
-    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
+    std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex);
 
-    if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
+    if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
+        buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
         // Create a new staging buffer if it doesn't exist or is too small
-        if (webgpu_ctx->get_tensor_staging_buf) {
-            webgpu_ctx->get_tensor_staging_buf.Destroy();
+        if (buf_ctx->global_ctx->get_tensor_staging_buf) {
+            buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
         }
-        ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
+        ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
                                   wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
     }
 
     // Copy the data from the buffer to the staging buffer
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
+    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
+                               final_size);
     wgpu::CommandBuffer commands = encoder.Finish();
 
     // Submit the command buffer to the queue
-    webgpu_ctx->queue.Submit(1, &commands);
+    buf_ctx->global_ctx->queue.Submit(1, &commands);
 
     // Map the staging buffer to read the data
-    ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
+    ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
+                                   wgpu::MapMode::Read, 0, final_size);
     // Must specify size here since the staging buffer might be larger than the tensor size
-    const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
+    const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
 
     // Copy the data from the mapped range to the output buffer
     std::memcpy(data, mapped_range, size);
-    webgpu_ctx->get_tensor_staging_buf.Unmap();
-    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx);
+    buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
+    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
 }
 
 static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
     WEBGPU_CPU_PROFILE_TOTAL_START(clear);
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
-    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx);
+    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
+    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
 }
 
 static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
@@ -2292,28 +2349,30 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
     int                     buffer_id = buffer_count++;
     std::string             buf_name  = "tensor_buf" + std::to_string(buffer_id);
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
-    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
 
-    wgpu::Buffer buf;
-    ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
+    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+    wgpu::Buffer                         buf;
+    ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
                               wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
                               buf_name.c_str());
 
     ggml_backend_webgpu_buffer_context * buf_ctx =
-        new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
+        new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
 
     return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
 }
 
 static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
-    return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
+    ggml_backend_webgpu_device_context * dev_ctx =
+        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+    return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
 }
 
 // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
 static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
-    return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
+    ggml_backend_webgpu_device_context * dev_ctx =
+        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+    return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
 }
 
 static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
@@ -2322,7 +2381,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
     size_t                               res = ggml_nbytes(tensor);
     switch (tensor->op) {
         case GGML_OP_ARGSORT:
-            res = ROUNDUP_POW2(res * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment,
+            res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
                                WEBGPU_STORAGE_BUF_BINDING_MULT);
             break;
         case GGML_OP_TOP_K:
@@ -2330,8 +2389,9 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                 const ggml_tensor * src0 = tensor->src[0];
                 if (src0) {
                     const size_t full = sizeof(int32_t) * ggml_nelements(src0);
-                    res               = ROUNDUP_POW2(full * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment,
-                                                     WEBGPU_STORAGE_BUF_BINDING_MULT);
+                    res               = ROUNDUP_POW2(
+                        full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
+                        WEBGPU_STORAGE_BUF_BINDING_MULT);
                 }
             }
             break;
@@ -2359,7 +2419,7 @@ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
     // TODO: for now, return maxBufferSize as both free and total memory
     // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
-    uint64_t                             max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
+    uint64_t                             max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
     // If we're on a 32-bit system, clamp to UINTPTR_MAX
 #if UINTPTR_MAX < UINT64_MAX
     uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
@@ -2402,66 +2462,67 @@ static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_si
     return constants;
 }
 
-static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
+static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
     // we use the maximum workgroup size for the memset pipeline
-    size_t max_threads                  = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
+    size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
     // Size the bytes_per_thread so that the largest buffer size can be handled
-    webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads);
+    ctx->capabilities.memset_bytes_per_thread =
+        CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
     std::vector<wgpu::ConstantEntry> constants(2);
-    constants[0].key                = "wg_size";
-    constants[0].value              = WEBGPU_MAX_WG_SIZE;
-    constants[1].key                = "bytes_per_thread";
-    constants[1].value              = webgpu_ctx->memset_bytes_per_thread;
-    webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants);
+    constants[0].key         = "wg_size";
+    constants[0].value       = WEBGPU_MAX_WG_SIZE;
+    constants[1].key         = "bytes_per_thread";
+    constants[1].value       = ctx->capabilities.memset_bytes_per_thread;
+    ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
 }
 
 static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
     // Q4/Q5/Q8 classic quantizations
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
 
     // K-quantizations
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
 
     // IQ quantizations (2-, 3-, 4-bit variants)
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
 
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
 
     // 1-bit and 4-bit IQ variants
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
 
     std::string proc_mul_mat_f32_f32;
     std::string proc_mul_mat_f32_f32_vec;
@@ -2474,18 +2535,18 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
 
     std::vector<wgpu::ConstantEntry> mul_mat_constants;
 #ifndef __EMSCRIPTEN__
-    if (webgpu_ctx->supports_subgroup_matrix) {
+    if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) {
         std::map<std::string, std::string> sg_matrix_repls;
-        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
+        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] =
+            std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size);
         sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
         sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
         sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
         sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
         sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
-        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_m);
-        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_n);
-        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_k);
-
+        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m);
+        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n);
+        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k);
         proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
         proc_mul_mat_f32_f32_vec =
             ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
@@ -2522,21 +2583,21 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
 #endif
 
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
     webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
+        webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
 
     std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
     mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
@@ -2547,171 +2608,171 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
     mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
 
     webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
+        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
     webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
+        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
     webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
+        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
     webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
+        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
     webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
+        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
     webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
+        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
     webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
+        webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
 }
 
 static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
+    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
 
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
 
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
 
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
+    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
+    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
     webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants);
     webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants);
     webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
     webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants);
     webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants);
     webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
     webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants);
     webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants);
     webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
     webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants);
     webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants);
     webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
     webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
 
     webgpu_ctx->rms_norm_pipelines[0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants);
-    webgpu_ctx->rms_norm_pipelines[1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
+    webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
 }
 
 static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
     webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
 
     webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
     webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
 }
 
 static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
@@ -2719,68 +2780,68 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
 
     // REGLU
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
 
     // GEGLU
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
 
     // SWIGLU
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
 
     // SWIGLU_OAI
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
 
     // GEGLU_ERF
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
 
     // GEGLU_QUICK
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
 }
 
 static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->scale_pipelines[0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants);
-    webgpu_ctx->scale_pipelines[1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants);
+    webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace,
+                                                                 "scale_f32_inplace", constants);
 }
 
 static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
@@ -2788,56 +2849,243 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
 
     // f32 (no mask)
     webgpu_ctx->soft_max_pipelines[2][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
-    webgpu_ctx->soft_max_pipelines[2][0][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
-    webgpu_ctx->soft_max_pipelines[2][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
+    webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
+    webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
     webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
 
     // f32 mask (mask_type = 0)
-    webgpu_ctx->soft_max_pipelines[0][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
+    webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
     webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
     webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
-    webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
+    webgpu_ctx->soft_max_pipelines[0][1][1] =
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
+                                    "soft_max_f32_mask_f32_sink_inplace", constants);
 
     // f16 mask (mask_type = 1)
-    webgpu_ctx->soft_max_pipelines[1][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
+    webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
     webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
     webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
-    webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
+    webgpu_ctx->soft_max_pipelines[1][1][1] =
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
+                                    "soft_max_f32_mask_f16_sink_inplace", constants);
+}
+
+static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
+    wgpu::RequestAdapterOptions options = {};
+
+#ifndef __EMSCRIPTEN__
+    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
+    const char * const          adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
+    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
+    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;
+    adapterTogglesDesc.enabledToggleCount = 2;
+    options.nextInChain                   = &adapterTogglesDesc;
+#endif
+
+    ctx->webgpu_global_ctx->instance.WaitAny(
+        ctx->webgpu_global_ctx->instance.RequestAdapter(
+            &options, wgpu::CallbackMode::AllowSpontaneous,
+            [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
+                if (status != wgpu::RequestAdapterStatus::Success) {
+                    GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
+                    return;
+                }
+                ctx->webgpu_global_ctx->adapter = std::move(adapter);
+            }),
+        UINT64_MAX);
+    GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
+
+    ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
+
+    wgpu::AdapterInfo info{};
+#ifndef __EMSCRIPTEN__
+    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
+    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        info.nextInChain = &subgroup_matrix_configs;
+    }
+#endif
+    ctx->webgpu_global_ctx->adapter.GetInfo(&info);
+    wgpu::SupportedFeatures features;
+    ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
+    // we require f16 support
+    GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
+
+#ifndef __EMSCRIPTEN__
+    // Only support square f16 matrices of size 8 or 16 for now
+    bool valid_subgroup_matrix_config = false;
+    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
+            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
+            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
+                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
+                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
+                ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
+                ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
+                ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
+                valid_subgroup_matrix_config                  = true;
+                break;
+            }
+        }
+    }
+    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
+#endif
+
+    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
+    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
+    ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
+    // Initialize device
+    std::vector<wgpu::FeatureName> required_features       = { wgpu::FeatureName::ShaderF16 };
+
+#ifndef __EMSCRIPTEN__
+    required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
+    if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
+        required_features.push_back(wgpu::FeatureName::Subgroups);
+        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+    }
+#endif
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    required_features.push_back(wgpu::FeatureName::TimestampQuery);
+#endif
+
+    wgpu::DeviceDescriptor dev_desc;
+    dev_desc.requiredLimits       = &ctx->webgpu_global_ctx->capabilities.limits;
+    dev_desc.requiredFeatures     = required_features.data();
+    dev_desc.requiredFeatureCount = required_features.size();
+    dev_desc.SetDeviceLostCallback(
+        wgpu::CallbackMode::AllowSpontaneous,
+        [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
+            GGML_UNUSED(device);
+            GGML_UNUSED(reason);
+            GGML_UNUSED(message);
+            //TODO: uncomment once proper free logic is in place
+            //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
+            //std::string(message).c_str());
+        });
+    dev_desc.SetUncapturedErrorCallback(
+        [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
+            GGML_UNUSED(device);
+            GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
+                       std::string(message).c_str());
+        });
+
+#ifndef __EMSCRIPTEN__
+    // Enable Dawn-specific toggles to increase native performance
+    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
+    //       only for native performance?
+    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
+                                                   "disable_polyfills_on_integer_div_and_mod" };
+    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
+    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
+    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
+    deviceTogglesDesc.enabledToggleCount  = 4;
+    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
+    deviceTogglesDesc.disabledToggleCount = 1;
+
+    dev_desc.nextInChain = &deviceTogglesDesc;
+#endif
+
+    ctx->webgpu_global_ctx->instance.WaitAny(
+        ctx->webgpu_global_ctx->adapter.RequestDevice(
+            &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
+            [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
+                if (status != wgpu::RequestDeviceStatus::Success) {
+                    GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
+                    return;
+                }
+                ctx->webgpu_global_ctx->device = std::move(device);
+            }),
+        UINT64_MAX);
+    GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
+
+    ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
+    ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                                 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                                                 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+    ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    // Initialize buffer pool for timestamp queries, used for profiling
+    ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
+                                              WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
+                                              wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
+                                              wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
+#endif
+
+    GGML_LOG_INFO(
+        "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
+        "device_desc: %s\n",
+        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
+        std::string(info.device).c_str(), std::string(info.description).c_str());
+    return true;
+}
+
+static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
+    ggml_backend_webgpu_device_context * dev_ctx    = (ggml_backend_webgpu_device_context *) dev->context;
+    webgpu_context                       webgpu_ctx = std::make_shared<webgpu_context_struct>();
+    webgpu_ctx->global_ctx                          = dev_ctx->webgpu_global_ctx;
+    webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                    wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                                    wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+    webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
+                                             WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+                                             wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
+                                             wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
+
+    ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
+    ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
+    ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
+    ggml_webgpu_init_add_pipeline(webgpu_ctx);
+    ggml_webgpu_init_sub_pipeline(webgpu_ctx);
+    ggml_webgpu_init_mul_pipeline(webgpu_ctx);
+    ggml_webgpu_init_div_pipeline(webgpu_ctx);
+    ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
+    ggml_webgpu_init_rope_pipeline(webgpu_ctx);
+    ggml_webgpu_init_glu_pipeline(webgpu_ctx);
+    ggml_webgpu_init_scale_pipeline(webgpu_ctx);
+    ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
+#ifdef GGML_WEBGPU_DEBUG
+    // Initialize debug buffers
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
+                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
+                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
+#endif
+    return webgpu_ctx;
 }
 
-// TODO: move most initialization logic here
-static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
+static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
 
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
 
-    ggml_backend_webgpu_device_context * dev_ctx    = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
-    webgpu_context                       webgpu_ctx = dev_ctx->webgpu_ctx;
+    ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
 
-    static ggml_backend_webgpu_context backend_ctx;
-    backend_ctx.name       = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
-    backend_ctx.webgpu_ctx = webgpu_ctx;
+    auto * backend_ctx      = new ggml_backend_webgpu_context();
+    backend_ctx->name       = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
+    backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
 
     // See GGML Backend Interface section
-    static ggml_backend backend = {
+    auto * backend = new ggml_backend();
+    *backend       = {
         /* .guid      = */ ggml_backend_webgpu_guid(),
         /* .interface = */ ggml_backend_webgpu_i,
         /* .device    = */ dev,
-        /* .context   = */ &backend_ctx,
+        /* .context   = */ backend_ctx,
     };
-    return &backend;
+    return backend;
 }
 
 static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -2854,7 +3102,8 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
         },
         /* .device  = */
         dev,
-        /* .context = */ NULL,
+        /* .context = */
+        NULL
     };
 
     return &ggml_backend_webgpu_buffer_type;
@@ -2895,16 +3144,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
 static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
 
-    webgpu_context webgpu_ctx = ctx->webgpu_ctx;
-
     ggml_tensor * src0 = op->src[0];
     ggml_tensor * src1 = op->src[1];
     ggml_tensor * src2 = op->src[2];
 
     // on smaller devices (or CI), tensors may be larger than the max storage buffer size
-    if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
-        (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
-        (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
+        (src0 != nullptr &&
+         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src1 != nullptr &&
+         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
         return false;
     }
 
@@ -2984,17 +3233,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             }
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                if (!webgpu_ctx->supports_subgroup_matrix) {
+#ifndef __EMSCRIPTEN__
+                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
                     break;
                 }
                 // Head dimensions must fit in workgroup memory with minimum tile sizes
-                size_t     limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
+                size_t     limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
                 const bool has_mask    = op->src[3] != nullptr;
-                const bool kv_direct   = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
+                const bool kv_direct   = src1->type == GGML_TYPE_F16 &&
+                                       (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
                                        (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
                 const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
-                    webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
-                    has_mask, kv_direct);
+                    ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
+                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
                 if (min_bytes > limit_bytes) {
                     break;
                 }
@@ -3003,6 +3254,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                               (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
                                src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
                               src2->type == src1->type && op->type == GGML_TYPE_F32;
+#endif
                 break;
             }
         case GGML_OP_RMS_NORM:
@@ -3099,10 +3351,13 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         default:
             break;
     }
-    if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
-        (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
-        (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
-        (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
+        (src0 != nullptr &&
+         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src1 != nullptr &&
+         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src2 != nullptr &&
+         ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
         supports_op = false;
         WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
     }
@@ -3127,7 +3382,7 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
     /* .get_memory           = */ ggml_backend_webgpu_device_get_memory,
     /* .get_type             = */ ggml_backend_webgpu_device_get_type,
     /* .get_props            = */ ggml_backend_webgpu_device_get_props,
-    /* .init_backend         = */ ggml_backend_webgpu_device_init,
+    /* .init_backend         = */ ggml_backend_webgpu_backend_init,
     /* .get_buffer_type      = */ ggml_backend_webgpu_device_get_buffer_type,
     /* .get_host_buffer_type = */ NULL,
     /* .buffer_from_host_ptr = */ NULL,
@@ -3156,6 +3411,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
 // TODO: Does this need to be thread safe? Is it only called once?
 // TODO: move most logic to device_init function so backend can be freed/initialized properly
 // Only one device is supported for now
+
 static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
     GGML_ASSERT(index == 0);
     WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
@@ -3164,189 +3420,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
 
     ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
 
-    webgpu_context ctx = reg_ctx->webgpu_ctx;
-
-    wgpu::RequestAdapterOptions options = {};
-
-#ifndef __EMSCRIPTEN__
-    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
-    const char * const          adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
-    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
-    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;
-    adapterTogglesDesc.enabledToggleCount = 2;
-    options.nextInChain                   = &adapterTogglesDesc;
-#endif
-
-    ctx->instance.WaitAny(ctx->instance.RequestAdapter(
-                              &options, wgpu::CallbackMode::AllowSpontaneous,
-                              [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
-                                  if (status != wgpu::RequestAdapterStatus::Success) {
-                                      GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
-                                      return;
-                                  }
-                                  ctx->adapter = std::move(adapter);
-                              }),
-                          UINT64_MAX);
-    GGML_ASSERT(ctx->adapter != nullptr);
-
-    ctx->adapter.GetLimits(&ctx->limits);
-
-    wgpu::AdapterInfo info{};
-#ifndef __EMSCRIPTEN__
-    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
-    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
-        info.nextInChain = &subgroup_matrix_configs;
-    }
-#endif
-    ctx->adapter.GetInfo(&info);
-
-    wgpu::SupportedFeatures features;
-    ctx->adapter.GetFeatures(&features);
-    // we require f16 support
-    GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
-
-#ifndef __EMSCRIPTEN__
-    // Only support square f16 matrices of size 8 or 16 for now
-    bool valid_subgroup_matrix_config = false;
-    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
-        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
-            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
-            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
-                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
-                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
-                ctx->sg_mat_m                = config.M;
-                ctx->sg_mat_n                = config.N;
-                ctx->sg_mat_k                = config.K;
-                valid_subgroup_matrix_config = true;
-                break;
-            }
-        }
-    }
-
-    ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
-#endif
-    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
-    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
-    ctx->max_subgroup_size = info.subgroupMaxSize;
-
-    // Initialize device
-    std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
-
-#ifndef __EMSCRIPTEN__
-    required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
-    if (ctx->supports_subgroup_matrix) {
-        required_features.push_back(wgpu::FeatureName::Subgroups);
-        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
-    }
-#endif
-
-#ifdef GGML_WEBGPU_GPU_PROFILE
-    required_features.push_back(wgpu::FeatureName::TimestampQuery);
-#endif
-
-    wgpu::DeviceDescriptor dev_desc;
-    dev_desc.requiredLimits       = &ctx->limits;
-    dev_desc.requiredFeatures     = required_features.data();
-    dev_desc.requiredFeatureCount = required_features.size();
-    dev_desc.SetDeviceLostCallback(
-        wgpu::CallbackMode::AllowSpontaneous,
-        [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
-            GGML_UNUSED(device);
-            GGML_UNUSED(reason);
-            GGML_UNUSED(message);
-            //TODO: uncomment once proper free logic is in place
-            //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
-            //std::string(message).c_str());
-        });
-    dev_desc.SetUncapturedErrorCallback(
-        [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
-            GGML_UNUSED(device);
-            GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
-                       std::string(message).c_str());
-        });
-
-#ifndef __EMSCRIPTEN__
-    // Enable Dawn-specific toggles to increase native performance
-    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
-    //       only for native performance?
-    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
-                                                   "disable_polyfills_on_integer_div_and_mod" };
-    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
-    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
-    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
-    deviceTogglesDesc.enabledToggleCount  = 4;
-    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
-    deviceTogglesDesc.disabledToggleCount = 1;
-
-    dev_desc.nextInChain = &deviceTogglesDesc;
-#endif
-
-    ctx->instance.WaitAny(ctx->adapter.RequestDevice(
-                              &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
-                              [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
-                                  if (status != wgpu::RequestDeviceStatus::Success) {
-                                      GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
-                                                     std::string(message).c_str());
-                                      return;
-                                  }
-                                  ctx->device = std::move(device);
-                              }),
-                          UINT64_MAX);
-    GGML_ASSERT(ctx->device != nullptr);
-
-    // Initialize (compute) queue
-    ctx->queue = ctx->device.GetQueue();
-
-    // Create buffer pool for shader parameters
-    ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                             wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                             wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
-
-#ifdef GGML_WEBGPU_GPU_PROFILE
-    // Initialize buffer pool for timestamp queries (profiling)
-    ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
-                                       WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
-                                       wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
-                                       wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
-#endif
-
-    ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
-                                      wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
-
-    ggml_webgpu_init_memset_pipeline(ctx);
-    ggml_webgpu_init_mul_mat_pipeline(ctx);
-    ggml_webgpu_init_get_rows_pipeline(ctx);
-    ggml_webgpu_init_cpy_pipeline(ctx);
-    ggml_webgpu_init_add_pipeline(ctx);
-    ggml_webgpu_init_sub_pipeline(ctx);
-    ggml_webgpu_init_mul_pipeline(ctx);
-    ggml_webgpu_init_div_pipeline(ctx);
-    ggml_webgpu_init_rms_norm_pipeline(ctx);
-    ggml_webgpu_init_rope_pipeline(ctx);
-    ggml_webgpu_init_glu_pipeline(ctx);
-    ggml_webgpu_init_scale_pipeline(ctx);
-    ggml_webgpu_init_soft_max_pipeline(ctx);
-
-#ifdef GGML_WEBGPU_DEBUG
-    // Initialize debug buffers
-    ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
-                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
-    ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
-                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
-#endif
+    create_webgpu_device(reg_ctx);
 
     static ggml_backend_webgpu_device_context device_ctx;
-    device_ctx.webgpu_ctx  = ctx;
-    device_ctx.device_name = GGML_WEBGPU_NAME;
-    device_ctx.device_desc = info.description;
-
-    GGML_LOG_INFO(
-        "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
-        "device_desc: %s\n",
-        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
-        std::string(info.device).c_str(), std::string(info.description).c_str());
-
+    device_ctx.device_name            = GGML_WEBGPU_NAME;
+    device_ctx.device_desc            = GGML_WEBGPU_NAME;
+    device_ctx.webgpu_global_ctx      = reg_ctx->webgpu_global_ctx;
     // See GGML Backend Device Interface section
     static ggml_backend_device device = {
         /* .iface   = */ ggml_backend_webgpu_device_i,
@@ -3354,7 +3433,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
         /* .context = */ &device_ctx,
     };
 
-    WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx);
+    WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
     return &device;
 }
 
@@ -3370,10 +3449,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
 ggml_backend_reg_t ggml_backend_webgpu_reg() {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
 
-    webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
-
     static ggml_backend_webgpu_reg_context ctx;
-    ctx.webgpu_ctx   = webgpu_ctx;
     ctx.name         = GGML_WEBGPU_NAME;
     ctx.device_count = 1;
 
@@ -3390,15 +3466,17 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
     instance_descriptor.nextInChain        = &instanceTogglesDesc;
 #endif
 
-    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
+    wgpu::Instance inst             = wgpu::CreateInstance(&instance_descriptor);
+    ctx.webgpu_global_ctx           = webgpu_global_context(new webgpu_global_context_struct());
+    ctx.webgpu_global_ctx->instance = std::move(inst);
 
 #ifdef __EMSCRIPTEN__
-    if (webgpu_ctx->instance == nullptr) {
+    if (ctx.webgpu_global_ctx->instance == nullptr) {
         GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
         return nullptr;
     }
 #endif
-    GGML_ASSERT(webgpu_ctx->instance != nullptr);
+    GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
 
     static ggml_backend_reg reg = {
         /* .api_version = */ GGML_BACKEND_API_VERSION,
@@ -3411,7 +3489,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
 ggml_backend_t ggml_backend_webgpu_init(void) {
     ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
 
-    return ggml_backend_webgpu_device_init(dev, nullptr);
+    return ggml_backend_webgpu_backend_init(dev, nullptr);
 }
 
 GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)