]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-webgpu: move from parameter buffer pool to single buffer with offsets (#21278)
authorReese Levine <redacted>
Fri, 3 Apr 2026 18:40:14 +0000 (11:40 -0700)
committerGitHub <redacted>
Fri, 3 Apr 2026 18:40:14 +0000 (11:40 -0700)
* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs

* Start work on removing parameter buffer pools

* Simplify and optimize further

* simplify profile futures

* Fix stride

* Try using a single command buffer per batch

* formatting

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp

index 1c56c689312f4d643ecbbc13d2a73e2721484163..669d2cd53a8354322745252f67c39b5d77c2e098 100644 (file)
@@ -437,12 +437,18 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_
 
     // Head-dim specializations used by the tuned vec f16 path.
     switch (key.head_dim_qk) {
-        case 64: return 2u;
-        case 96: return 4u;
-        case 128: return 1u;
-        case 192: return 2u;
-        case 576: return 2u;
-        default: return 1u;
+        case 64:
+            return 2u;
+        case 96:
+            return 4u;
+        case 128:
+            return 1u;
+        case 192:
+            return 2u;
+        case 576:
+            return 2u;
+        default:
+            return 1u;
     }
 }
 
@@ -513,9 +519,9 @@ struct ggml_webgpu_flash_attn_blk_shader_lib_context {
 };
 
 inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
-    pre_wgsl::Preprocessor &                                    preprocessor,
-    const char *                                                shader_src,
-    const ggml_webgpu_flash_attn_blk_shader_lib_context &       context) {
+    pre_wgsl::Preprocessor &                              preprocessor,
+    const char *                                          shader_src,
+    const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
     std::vector<std::string> defines;
     std::string              variant = "flash_attn_vec_blk";
 
@@ -1857,9 +1863,8 @@ class ggml_webgpu_shader_lib {
         defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
 
         uint32_t q_tile  = context.sg_mat_m;
-        uint32_t kv_tile =
-            std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
-                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+        uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
+                                    context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
         if (context.key.use_vec) {
             q_tile  = 1;
             kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
@@ -1885,14 +1890,14 @@ class ggml_webgpu_shader_lib {
         }
         defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
 
-        const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
+        const char *    shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
         webgpu_pipeline pipeline =
             ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
-        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
-        decisions->q_tile  = q_tile;
-        decisions->kv_tile = kv_tile;
-        decisions->wg_size = wg_size;
-        pipeline.context   = decisions;
+        auto decisions                    = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
+        decisions->q_tile                 = q_tile;
+        decisions->kv_tile                = kv_tile;
+        decisions->wg_size                = wg_size;
+        pipeline.context                  = decisions;
         flash_attn_pipelines[context.key] = pipeline;
         return flash_attn_pipelines[context.key];
     }
@@ -1905,7 +1910,7 @@ class ggml_webgpu_shader_lib {
 
         ggml_webgpu_processed_shader processed =
             ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
-        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
+        webgpu_pipeline pipeline              = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
         flash_attn_blk_pipelines[context.key] = pipeline;
         return flash_attn_blk_pipelines[context.key];
     }
index e53281bfbbd4df39348dcf1769624e8b419ea3c3..5c567dc0df07bab5a77642a4cbf0067842b7b557 100644 (file)
@@ -81,12 +81,10 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
 
 /* Constants */
 
-#define WEBGPU_NUM_PARAM_BUFS                96u
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     32u
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
+#define WEBGPU_NUM_PARAM_SLOTS \
+    (WEBGPU_COMMAND_SUBMIT_BATCH_SIZE + 10)  // a few extra for safety, since some operations may need multiple slots
 #define WEBGPU_WAIT_ANY_TIMEOUT_MS           100
-// Maximum number of in-flight submissions per-thread, to avoid exhausting the
-// parameter buffer pool
-#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
 #define WEBGPU_STORAGE_BUF_BINDING_MULT      4    // a storage buffer binding size must be a multiple of 4
@@ -122,87 +120,45 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
                                       wgpu::BufferUsage usage,
                                       const char *      label);
 
-// Holds a pool of parameter buffers for WebGPU operations
-struct webgpu_buf_pool {
-    std::vector<wgpu::Buffer> free;
-
-    // The pool must be synchronized because
-    // 1. The memset pool is shared globally by every ggml buffer,
-    // since allocating a pool per ggml buffer would consume too much memory.
-    // 2. For the per-thread buffer pools in webgpu_context,
-    // buffers are allocated and freed in Dawn callbacks,
-    // which can run on a different thread than the calling thread.
-    std::mutex              mutex;
-    std::condition_variable cv;
-    size_t                  cur_pool_size;
-    size_t                  max_pool_size;
-    wgpu::Device            device;
-    wgpu::BufferUsage       dev_buf_usage;
-    size_t                  buf_size;
-    bool                    should_grow;
-
-    void init(wgpu::Device      device,
-              int               num_bufs,
-              size_t            buf_size,
-              wgpu::BufferUsage dev_buf_usage,
-              bool              should_grow   = false,
-              size_t            max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
-        this->max_pool_size = max_pool_size;
-        this->cur_pool_size = num_bufs;
-        this->device        = device;
-        this->dev_buf_usage = dev_buf_usage;
-        this->buf_size      = buf_size;
-        this->should_grow   = should_grow;
-        for (int i = 0; i < num_bufs; i++) {
-            wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
-            free.push_back(dev_buf);
+// Slot-based parameter arena for compute graph encoding. Each encoded kernel
+// gets a unique uniform-buffer slice within the current batch, and the slot
+// cursor is reset immediately after that batch is submitted.
+struct webgpu_param_arena {
+    wgpu::Buffer buffer;
+    size_t       slot_stride = 0;
+    size_t       slot_size   = 0;
+    uint32_t     slot_count  = 0;
+    uint32_t     next_slot   = 0;
+
+    void init(wgpu::Device device, size_t slot_size, uint32_t slot_count, size_t alignment) {
+        this->slot_stride = ROUNDUP_POW2(slot_size, alignment);
+        this->slot_size   = slot_size;
+        this->slot_count  = slot_count;
+        this->next_slot   = 0;
+
+        ggml_webgpu_create_buffer(device, buffer, this->slot_stride * slot_count,
+                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_param_arena");
+    }
+
+    size_t alloc_slot(size_t size) {
+        GGML_ASSERT(size <= slot_size);
+        if (next_slot >= slot_count) {
+            GGML_ABORT("ggml_webgpu: parameter arena exhausted while encoding a batch");
         }
-    }
 
-    wgpu::Buffer alloc_bufs() {
-        std::unique_lock<std::mutex> lock(mutex);
-        if (!free.empty()) {
-            wgpu::Buffer buf = free.back();
-            free.pop_back();
-            return buf;
-        }
-
-        // Try growing the pool if no free buffers
-        if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
-            cur_pool_size++;
-            lock.unlock();  // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks
-            wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
-
-            if (!dev_buf) {
-                GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
-            }
-            return dev_buf;
-        }
-        cv.wait(lock, [this] { return !free.empty(); });
-        wgpu::Buffer buf = free.back();
-        free.pop_back();
-        return buf;
+        return slot_stride * next_slot++;
     }
 
-    void free_bufs(std::vector<wgpu::Buffer> bufs) {
-        std::lock_guard<std::mutex> lock(mutex);
-        free.insert(free.end(), bufs.begin(), bufs.end());
-        cv.notify_all();
-    }
+    void reset() { next_slot = 0; }
 
     void cleanup() {
-        std::lock_guard<std::mutex> lock(mutex);
-        for (auto & buf : free) {
-            if (buf) {
-                buf.Destroy();
-            }
+        if (buffer) {
+            buffer.Destroy();
+            buffer = nullptr;
         }
-        free.clear();
     }
 
-    ~webgpu_buf_pool() { this->cleanup(); }
+    ~webgpu_param_arena() { this->cleanup(); }
 };
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
@@ -269,10 +225,8 @@ struct webgpu_gpu_profile_buf_pool {
 };
 #endif
 
-struct webgpu_command {
-    uint32_t                  num_kernels;
-    wgpu::CommandBuffer       commands;
-    std::vector<wgpu::Buffer> params_bufs;
+struct webgpu_encoded_op {
+    uint32_t num_kernels = 0;
 #ifdef GGML_WEBGPU_GPU_PROFILE
     webgpu_gpu_profile_bufs timestamp_query_bufs;
     std::string             pipeline_name;
@@ -305,8 +259,8 @@ struct webgpu_global_context_struct {
     // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
     std::recursive_mutex mutex;
 
-    webgpu_buf_pool                memset_buf_pool;
-    std::map<int, webgpu_pipeline> memset_pipelines;  // variant or type index
+    wgpu::Buffer    memset_params_buf;
+    webgpu_pipeline memset_pipeline;
 
 #ifdef GGML_WEBGPU_CPU_PROFILE
     // Profiling: labeled CPU time in ms (total)
@@ -332,6 +286,10 @@ struct webgpu_global_context_struct {
             this->get_tensor_staging_buf.Destroy();
             this->get_tensor_staging_buf = nullptr;
         }
+        if (this->memset_params_buf) {
+            this->memset_params_buf.Destroy();
+            this->memset_params_buf = nullptr;
+        }
 #ifdef GGML_WEBGPU_DEBUG
         if (this->debug_host_buf) {
             this->debug_host_buf.Destroy();
@@ -347,13 +305,6 @@ struct webgpu_global_context_struct {
 
 typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
 
-struct webgpu_submission {
-    wgpu::FutureWaitInfo submit_done;
-#ifdef GGML_WEBGPU_GPU_PROFILE
-    std::vector<wgpu::FutureWaitInfo> profile_futures;
-#endif
-};
-
 // All the base objects needed to run operations on a WebGPU device
 struct webgpu_context_struct {
     // Points to global instances owned by ggml_backend_webgpu_reg_context
@@ -361,9 +312,9 @@ struct webgpu_context_struct {
 
     std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
 
-    webgpu_buf_pool param_buf_pool;
-    wgpu::Buffer    set_rows_dev_error_buf;
-    wgpu::Buffer    set_rows_host_error_buf;
+    webgpu_param_arena param_arena;
+    wgpu::Buffer       set_rows_dev_error_buf;
+    wgpu::Buffer       set_rows_host_error_buf;
 
     size_t memset_bytes_per_thread;
 };
@@ -448,95 +399,34 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 
 /** WebGPU Actions */
 
-static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
-    switch (status) {
-        case wgpu::WaitStatus::Success:
-            return true;
-        case wgpu::WaitStatus::TimedOut:
-            if (allow_timeout) {
-                return false;
-            }
-            GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
-            return false;
-        case wgpu::WaitStatus::Error:
-            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
-            return false;
-        default:
-            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
-            return false;
-    }
-}
-
 #ifdef GGML_WEBGPU_GPU_PROFILE
-static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
-    futures.erase(std::remove_if(futures.begin(), futures.end(),
-                                 [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
-                  futures.end());
-}
-
 static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context &             ctx,
-                                                     std::vector<wgpu::FutureWaitInfo> & futures,
-                                                     bool                                block) {
+                                                     std::vector<wgpu::FutureWaitInfo> & futures) {
     if (futures.empty()) {
         return;
     }
 
-    uint64_t timeout_ms = block ? UINT64_MAX : 0;
-    if (block) {
-        while (!futures.empty()) {
-            auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
-            if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
-                ggml_backend_webgpu_erase_completed_futures(futures);
-            }
-        }
-    } else {
-        auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
-        if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
-            ggml_backend_webgpu_erase_completed_futures(futures);
-        }
+    constexpr size_t max_futures_per_wait = 64;
+
+    while (!futures.empty()) {
+        ctx->instance.WaitAny(std::min(max_futures_per_wait, futures.size()), futures.data(), UINT64_MAX);
+        futures.erase(std::remove_if(futures.begin(), futures.end(),
+                                     [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
+                      futures.end());
     }
 }
 #endif
 
-// Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_global_context &          ctx,
-                                     std::vector<webgpu_submission> & subs,
-                                     bool                             block = true) {
-    if (subs.empty()) {
-        return;
-    }
-
-    bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
-    while (blocking_wait) {
-        auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6);
-        if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
-#ifdef GGML_WEBGPU_GPU_PROFILE
-            ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
-#endif
-            subs.erase(subs.begin());
-        }
-        blocking_wait = (block && !subs.empty()) || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
-    }
-
-    if (subs.empty()) {
-        return;
-    }
-
-    // Poll each submit future once and remove completed submissions.
-    for (auto sub = subs.begin(); sub != subs.end();) {
-        auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
-        bool success    = ggml_backend_webgpu_handle_wait_status(waitStatus, true);
-#ifdef GGML_WEBGPU_GPU_PROFILE
-        ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
-        if (success && sub->profile_futures.empty()) {
-#else
-        if (success) {
-#endif
-            sub = subs.erase(sub);
-        } else {
-            ++sub;
-        }
-    }
+static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
+    ctx->instance.WaitAny(
+        ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
+                                       [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+                                           if (status != wgpu::QueueWorkDoneStatus::Success) {
+                                               GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
+                                                              std::string(message).c_str());
+                                           }
+                                       }),
+        UINT64_MAX);
 }
 
 static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
@@ -570,34 +460,10 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
 }
 #endif
 
-static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context &       ctx,
-                                                    std::vector<webgpu_command> & commands,
-                                                    webgpu_buf_pool &             param_buf_pool) {
-    std::vector<wgpu::CommandBuffer> command_buffers;
-    std::vector<wgpu::Buffer>        params_bufs;
-    webgpu_submission                submission;
-#ifdef GGML_WEBGPU_GPU_PROFILE
-    std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
-#endif
-
-    for (const auto & command : commands) {
-        command_buffers.push_back(command.commands);
-        params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
-    }
-    ctx->queue.Submit(command_buffers.size(), command_buffers.data());
-
-    wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
-        wgpu::CallbackMode::AllowSpontaneous,
-        [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
-            if (status != wgpu::QueueWorkDoneStatus::Success) {
-                GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
-            }
-            // Free the staged buffers
-            param_buf_pool.free_bufs(params_bufs);
-        });
-    submission.submit_done = { p_f };
-
 #ifdef GGML_WEBGPU_GPU_PROFILE
+static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context &             ctx,
+                                                        const std::vector<webgpu_command> & commands,
+                                                        std::vector<wgpu::FutureWaitInfo> & futures) {
     for (const auto & command : commands) {
         auto label   = command.pipeline_name;
         auto ts_bufs = command.timestamp_query_bufs;
@@ -616,15 +482,15 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context &
                 // We can't unmap in here due to WebGPU reentrancy limitations.
                 ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
             });
-        submission.profile_futures.push_back({ f });
+        futures.push_back({ f });
     }
-#endif
-    return submission;
 }
+#endif
 
-static webgpu_command ggml_backend_webgpu_build_multi(
+static webgpu_encoded_op ggml_backend_webgpu_build_multi(
     webgpu_global_context &                                ctx,
-    webgpu_buf_pool &                                      param_buf_pool,
+    webgpu_param_arena &                                   param_arena,
+    wgpu::CommandEncoder &                                 encoder,
     const std::vector<webgpu_pipeline> &                   pipelines,
     const std::vector<std::vector<uint32_t>> &             params_list,
     const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
@@ -633,16 +499,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
     GGML_ASSERT(pipelines.size() == workgroups_list.size());
 
-    std::vector<wgpu::Buffer>    params_bufs_list;
+    webgpu_encoded_op            result = {};
     std::vector<wgpu::BindGroup> bind_groups;
+    std::vector<size_t>          param_offsets;
+    result.num_kernels = pipelines.size();
 
     for (size_t i = 0; i < pipelines.size(); i++) {
-        wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
+        const size_t param_size   = params_list[i].size() * sizeof(uint32_t);
+        const size_t param_offset = param_arena.alloc_slot(param_size);
 
         std::vector<wgpu::BindGroupEntry> entries            = bind_group_entries_list[i];
         uint32_t                          params_binding_num = entries.size();
-        entries.push_back(
-            { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
+        entries.push_back({ .binding = params_binding_num,
+                            .buffer  = param_arena.buffer,
+                            .offset  = param_offset,
+                            .size    = param_arena.slot_size });
 
         wgpu::BindGroupDescriptor bind_group_desc;
         bind_group_desc.layout     = pipelines[i].pipeline.GetBindGroupLayout(0);
@@ -650,13 +521,12 @@ static webgpu_command ggml_backend_webgpu_build_multi(
         bind_group_desc.entries    = entries.data();
         bind_group_desc.label      = pipelines[i].name.c_str();
         bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
-
-        params_bufs_list.push_back(params_bufs);
+        param_offsets.push_back(param_offset);
     }
 
-    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
-    for (size_t i = 0; i < params_bufs_list.size(); i++) {
-        ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
+    for (size_t i = 0; i < param_offsets.size(); i++) {
+        ctx->queue.WriteBuffer(param_arena.buffer, param_offsets[i], params_list[i].data(),
+                               params_list[i].size() * sizeof(uint32_t));
     }
 #ifdef GGML_WEBGPU_GPU_PROFILE
     webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
@@ -682,29 +552,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
 #ifdef GGML_WEBGPU_GPU_PROFILE
     encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
     encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
-#endif
-
-    wgpu::CommandBuffer commands = encoder.Finish();
-    webgpu_command      result   = {};
-    result.commands              = commands;
-    result.params_bufs           = params_bufs_list;
-    result.num_kernels           = pipelines.size();
-#ifdef GGML_WEBGPU_GPU_PROFILE
     result.timestamp_query_bufs = ts_bufs;
-    // TODO: handle multiple pipeline names
     result.pipeline_name        = pipelines.front().name;
 #endif
     return result;
 }
 
-static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &           ctx,
-                                                webgpu_buf_pool &                 param_buf_pool,
-                                                webgpu_pipeline &                 pipeline,
-                                                std::vector<uint32_t>             params,
-                                                std::vector<wgpu::BindGroupEntry> bind_group_entries,
-                                                uint32_t                          wg_x,
-                                                uint32_t                          wg_y = 1) {
-    return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
+static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_global_context &           ctx,
+                                                   webgpu_param_arena &              param_arena,
+                                                   wgpu::CommandEncoder &            encoder,
+                                                   webgpu_pipeline &                 pipeline,
+                                                   std::vector<uint32_t>             params,
+                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
+                                                   uint32_t                          wg_x,
+                                                   uint32_t                          wg_y = 1) {
+    return ggml_backend_webgpu_build_multi(ctx, param_arena, encoder,
                                            {
                                                pipeline
     },
@@ -724,10 +586,28 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
     size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
     uint32_t wg_x         = CEIL_DIV(size + 3, bytes_per_wg);
 
-    webgpu_command command =
-        ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
-    std::vector<webgpu_command>    commands = { command };
-    std::vector<webgpu_submission> sub      = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
+    ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t));
+
+    entries.push_back(
+        { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES });
+
+    wgpu::BindGroupDescriptor bind_group_desc;
+    bind_group_desc.layout     = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0);
+    bind_group_desc.entryCount = entries.size();
+    bind_group_desc.entries    = entries.data();
+    bind_group_desc.label      = ctx->memset_pipeline.name.c_str();
+    wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
+
+    wgpu::CommandEncoder     encoder = ctx->device.CreateCommandEncoder();
+    wgpu::ComputePassEncoder pass    = encoder.BeginComputePass();
+    pass.SetPipeline(ctx->memset_pipeline.pipeline);
+    pass.SetBindGroup(0, bind_group);
+    pass.DispatchWorkgroups(wg_x, 1, 1);
+    pass.End();
+
+    wgpu::CommandBuffer              command  = encoder.Finish();
+    std::vector<wgpu::CommandBuffer> commands = { command };
+    ctx->queue.Submit(commands.size(), commands.data());
 }
 
 /** End WebGPU Actions */
@@ -840,7 +720,10 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
     return flags;
 }
 
-static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context &       ctx,
+                                         wgpu::CommandEncoder & encoder,
+                                         ggml_tensor *          src,
+                                         ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0        = src,
         .dst         = dst,
@@ -878,10 +761,14 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
     };
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_set(webgpu_context &       ctx,
+                                         wgpu::CommandEncoder & encoder,
+                                         ggml_tensor *          src0,
+                                         ggml_tensor *          src1,
+                                         ggml_tensor *          dst) {
     const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -940,10 +827,13 @@ static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0,
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_pad(webgpu_context &       ctx,
+                                         wgpu::CommandEncoder & encoder,
+                                         ggml_tensor *          src,
+                                         ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
     };
@@ -995,13 +885,14 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
     };
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
-                                            ggml_tensor *    src0,
-                                            ggml_tensor *    src1,
-                                            ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context &       ctx,
+                                               wgpu::CommandEncoder & encoder,
+                                               ggml_tensor *          src0,
+                                               ggml_tensor *          src1,
+                                               ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0               = src0,
         .src1               = src1,
@@ -1056,13 +947,14 @@ static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
 
     const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
     const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
 }
 
-static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
-                                           ggml_tensor *    src0,
-                                           ggml_tensor *    src1,
-                                           ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context &       ctx,
+                                              wgpu::CommandEncoder & encoder,
+                                              ggml_tensor *          src0,
+                                              ggml_tensor *          src1,
+                                              ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0        = src0,
         .src1        = src1,
@@ -1112,17 +1004,18 @@ static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
 
     const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
     const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
-}
-
-static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
-                                                  ggml_tensor *    src0,
-                                                  ggml_tensor *    src1,
-                                                  ggml_tensor *    src2,
-                                                  ggml_tensor *    src3,
-                                                  ggml_tensor *    src4,
-                                                  ggml_tensor *    src5,
-                                                  ggml_tensor *    dst) {
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
+}
+
+static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context &       ctx,
+                                                     wgpu::CommandEncoder & encoder,
+                                                     ggml_tensor *          src0,
+                                                     ggml_tensor *          src1,
+                                                     ggml_tensor *          src2,
+                                                     ggml_tensor *          src3,
+                                                     ggml_tensor *          src4,
+                                                     ggml_tensor *          src5,
+                                                     ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0        = src0,
         .src1        = src1,
@@ -1197,13 +1090,14 @@ static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }
     };
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, h, n_seqs);
 }
 
-static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
-                                                          ggml_tensor *    src,
-                                                          ggml_tensor *    idx,
-                                                          ggml_tensor *    dst) {
+static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context &       ctx,
+                                                             wgpu::CommandEncoder & encoder,
+                                                             ggml_tensor *          src,
+                                                             ggml_tensor *          idx,
+                                                             ggml_tensor *          dst) {
     // For set rows specifically, we need to check if src and idx are empty
     // tensors.
     if (ggml_is_empty(src) || ggml_is_empty(idx)) {
@@ -1266,7 +1160,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
         threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
     }
     uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, 1);
 }
 
 // Workgroup size is a common constant
@@ -1277,10 +1171,11 @@ static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_si
     return constants;
 }
 
-static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
-                                           ggml_tensor *    src,
-                                           ggml_tensor *    idx,
-                                           ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context &       ctx,
+                                              wgpu::CommandEncoder & encoder,
+                                              ggml_tensor *          src,
+                                              ggml_tensor *          idx,
+                                              ggml_tensor *          dst) {
     const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -1332,13 +1227,14 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
     uint32_t total_threads  = float_parallel ? blocks_per_row * total_rows : total_rows;
     uint32_t wg_x           = CEIL_DIV(total_threads, decisions->wg_size);
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
-                                          ggml_tensor *    src0,
-                                          ggml_tensor *    src1,
-                                          ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context &       ctx,
+                                             wgpu::CommandEncoder & encoder,
+                                             ggml_tensor *          src0,
+                                             ggml_tensor *          src1,
+                                             ggml_tensor *          dst) {
     // Determine if this is a mat-vec operation
     bool is_vec = (dst->ne[1] == 1);
 
@@ -1477,16 +1373,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
         compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
     }
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
 }
 
-static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
-                                             ggml_tensor *    Q,
-                                             ggml_tensor *    K,
-                                             ggml_tensor *    V,
-                                             ggml_tensor *    mask,
-                                             ggml_tensor *    sinks,
-                                             ggml_tensor *    dst) {
+#ifndef __EMSCRIPTEN__
+static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context &       ctx,
+                                                wgpu::CommandEncoder & encoder,
+                                                ggml_tensor *          Q,
+                                                ggml_tensor *          K,
+                                                ggml_tensor *          V,
+                                                ggml_tensor *          mask,
+                                                ggml_tensor *          sinks,
+                                                ggml_tensor *          dst) {
     float scale = *(float *) dst->op_params;
     float max_bias;
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
@@ -1575,9 +1473,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
         K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
     const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
                          (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
-    const uint32_t vec_nwg_cap =
-        std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
-    const bool use_blk = use_vec && has_mask;
+    const uint32_t vec_nwg_cap = std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
+    const bool     use_blk     = use_vec && has_mask;
 
     ggml_webgpu_flash_attn_pipeline_key key = {
         .kv_type            = K->type,
@@ -1656,9 +1553,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
         if (use_blk) {
             GGML_ASSERT(has_mask);
 
-            blk_nblk0       = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
-            blk_nblk1       = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
-            blk_buf         = ggml_webgpu_tensor_buf(dst);
+            blk_nblk0                   = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
+            blk_nblk1                   = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
+            blk_buf                     = ggml_webgpu_tensor_buf(dst);
             const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
             blk_batch_count             = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
             const uint64_t blk_elems    = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
@@ -1729,8 +1626,10 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                       .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });
         }
         if (use_blk) {
-            split_entries.push_back(
-                { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
+            split_entries.push_back({ .binding = split_binding_index++,
+                                      .buffer  = blk_buf,
+                                      .offset  = blk_entries[1].offset,
+                                      .size    = blk_size_bytes });
         }
         split_entries.push_back(
             { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
@@ -1799,14 +1698,18 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
             workgroups_list.push_back({ (uint32_t) nrows, 1u });
         }
 
-        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
                                                entries_list, workgroups_list);
     }
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
+#endif  // __EMSCRIPTEN__
 
-static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context &       ctx,
+                                              wgpu::CommandEncoder & encoder,
+                                              ggml_tensor *          src,
+                                              ggml_tensor *          dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
     bool inplace  = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
 
@@ -1881,13 +1784,14 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
     }
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
-                                            ggml_tensor *    src0,
-                                            ggml_tensor *    src1,
-                                            ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context &       ctx,
+                                               wgpu::CommandEncoder & encoder,
+                                               ggml_tensor *          src0,
+                                               ggml_tensor *          src1,
+                                               ggml_tensor *          dst) {
     binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -1983,13 +1887,14 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
     }
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
-                                         ggml_tensor *    src0,
-                                         ggml_tensor *    src1,
-                                         ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_concat(webgpu_context &       ctx,
+                                            wgpu::CommandEncoder & encoder,
+                                            ggml_tensor *          src0,
+                                            ggml_tensor *          src1,
+                                            ggml_tensor *          dst) {
     uint32_t ne  = (uint32_t) ggml_nelements(dst);
     uint32_t dim = (uint32_t) dst->op_params[0];
 
@@ -2039,10 +1944,13 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
     webgpu_pipeline pipeline  = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
     auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
     uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context &       ctx,
+                                            wgpu::CommandEncoder & encoder,
+                                            ggml_tensor *          src0,
+                                            ggml_tensor *          dst) {
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
     std::vector<uint32_t> params = { ne,
@@ -2081,10 +1989,13 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
     webgpu_pipeline pipeline  = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
     auto *          decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
     uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context &       ctx,
+                                              wgpu::CommandEncoder & encoder,
+                                              ggml_tensor *          src,
+                                              ggml_tensor *          dst) {
     bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
     std::vector<uint32_t> params = {
@@ -2124,14 +2035,16 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s
     };
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries,
+                                     ggml_nrows(src));
 }
 
-static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
-                                       ggml_tensor *    src0,
-                                       ggml_tensor *    src1,
-                                       ggml_tensor *    src2,
-                                       ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_rope(webgpu_context &       ctx,
+                                          wgpu::CommandEncoder & encoder,
+                                          ggml_tensor *          src0,
+                                          ggml_tensor *          src1,
+                                          ggml_tensor *          src2,
+                                          ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0        = src0,
         .src1        = src1,
@@ -2228,10 +2141,14 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
     }
 
     uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_glu(webgpu_context &       ctx,
+                                         wgpu::CommandEncoder & encoder,
+                                         ggml_tensor *          src0,
+                                         ggml_tensor *          src1,
+                                         ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0        = src0,
         .src1        = src1,
@@ -2290,10 +2207,13 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
     uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_scale(webgpu_context &       ctx,
+                                           wgpu::CommandEncoder & encoder,
+                                           ggml_tensor *          src,
+                                           ggml_tensor *          dst) {
     bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -2341,14 +2261,15 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
     }
 
     uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
-                                           ggml_tensor *    src0,
-                                           ggml_tensor *    src1,
-                                           ggml_tensor *    src2,
-                                           ggml_tensor *    dst) {
+static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context &       ctx,
+                                              wgpu::CommandEncoder & encoder,
+                                              ggml_tensor *          src0,
+                                              ggml_tensor *          src1,
+                                              ggml_tensor *          src2,
+                                              ggml_tensor *          dst) {
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
         .src0        = src0,
         .src1        = src1,
@@ -2424,10 +2345,14 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst));
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries,
+                                     ggml_nrows(dst));
 }
 
-static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context &       ctx,
+                                            wgpu::CommandEncoder & encoder,
+                                            ggml_tensor *          src,
+                                            ggml_tensor *          dst) {
     std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
                                      (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
                                      (uint32_t) src->ne[0] };
@@ -2449,10 +2374,13 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
     uint32_t        wg_x     = ggml_nelements(dst);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context &       ctx,
+                                             wgpu::CommandEncoder & encoder,
+                                             ggml_tensor *          src,
+                                             ggml_tensor *          dst) {
     bool is_top_k = dst->op == GGML_OP_TOP_K;
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
@@ -2543,7 +2471,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
     workgroups_list.push_back({ wg_x_init, wg_y_init });
 
     if (merge_passes == 0) {
-        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
                                                entries_list, workgroups_list);
     }
 
@@ -2605,11 +2533,14 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
         in_is_tmp = !in_is_tmp;
     }
 
-    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
-                                           workgroups_list);
+    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
+                                           entries_list, workgroups_list);
 }
 
-static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context &       ctx,
+                                            wgpu::CommandEncoder & encoder,
+                                            ggml_tensor *          src,
+                                            ggml_tensor *          dst) {
     std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
                                      (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
                                      (uint32_t) src->ne[0] };
@@ -2634,10 +2565,13 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
     uint32_t        wg_x     = ggml_nrows(dst);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context &       ctx,
+                                              wgpu::CommandEncoder & encoder,
+                                              ggml_tensor *          src,
+                                              ggml_tensor *          dst) {
     bool                  total_sum = dst->op == GGML_OP_SUM;
     std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
                                      (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -2666,11 +2600,13 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
     webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
 
     uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x);
 }
 
 // Returns the encoded command, or std::nullopt if the operation is a no-op
-static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
+static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context         ctx,
+                                                                wgpu::CommandEncoder & encoder,
+                                                                ggml_tensor *          node) {
     if (ggml_is_empty(node)) {
         return std::nullopt;
     }
@@ -2693,18 +2629,18 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
             return std::nullopt;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
-            return ggml_webgpu_cpy(ctx, src0, node);
+            return ggml_webgpu_cpy(ctx, encoder, src0, node);
         case GGML_OP_SET:
-            return ggml_webgpu_set(ctx, src0, src1, node);
+            return ggml_webgpu_set(ctx, encoder, src0, src1, node);
         case GGML_OP_SET_ROWS:
-            return ggml_webgpu_set_rows(ctx, src0, src1, node);
+            return ggml_webgpu_set_rows(ctx, encoder, src0, src1, node);
         case GGML_OP_GET_ROWS:
-            return ggml_webgpu_get_rows(ctx, src0, src1, node);
+            return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node);
         case GGML_OP_MUL_MAT:
-            return ggml_webgpu_mul_mat(ctx, src0, src1, node);
+            return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node);
         case GGML_OP_FLASH_ATTN_EXT:
 #ifndef __EMSCRIPTEN__
-            return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
+            return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node);
 #else
             return std::nullopt;
 #endif
@@ -2712,22 +2648,22 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-            return ggml_webgpu_binary_op(ctx, src0, src1, node);
+            return ggml_webgpu_binary_op(ctx, encoder, src0, src1, node);
         case GGML_OP_CONCAT:
-            return ggml_webgpu_concat(ctx, src0, src1, node);
+            return ggml_webgpu_concat(ctx, encoder, src0, src1, node);
         case GGML_OP_REPEAT:
-            return ggml_webgpu_repeat(ctx, src0, node);
+            return ggml_webgpu_repeat(ctx, encoder, src0, node);
         case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
-            return ggml_webgpu_row_norm(ctx, src0, node);
+            return ggml_webgpu_row_norm(ctx, encoder, src0, node);
         case GGML_OP_ROPE:
-            return ggml_webgpu_rope(ctx, src0, src1, src2, node);
+            return ggml_webgpu_rope(ctx, encoder, src0, src1, src2, node);
         case GGML_OP_GLU:
-            return ggml_webgpu_glu(ctx, src0, src1, node);
+            return ggml_webgpu_glu(ctx, encoder, src0, src1, node);
         case GGML_OP_SCALE:
-            return ggml_webgpu_scale(ctx, src0, node);
+            return ggml_webgpu_scale(ctx, encoder, src0, node);
         case GGML_OP_SOFT_MAX:
-            return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
+            return ggml_webgpu_soft_max(ctx, encoder, src0, src1, src2, node);
         case GGML_OP_UNARY:
         case GGML_OP_CLAMP:
         case GGML_OP_FILL:
@@ -2738,26 +2674,27 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_COS:
         case GGML_OP_DIAG:
         case GGML_OP_TRI:
-            return ggml_webgpu_unary_op(ctx, src0, node);
+            return ggml_webgpu_unary_op(ctx, encoder, src0, node);
         case GGML_OP_SOLVE_TRI:
-            return ggml_webgpu_solve_tri(ctx, src0, src1, node);
+            return ggml_webgpu_solve_tri(ctx, encoder, src0, src1, node);
         case GGML_OP_SSM_CONV:
-            return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
+            return ggml_webgpu_ssm_conv(ctx, encoder, src0, src1, node);
         case GGML_OP_GATED_DELTA_NET:
-            return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
+            return ggml_webgpu_gated_delta_net(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node->src[5],
+                                               node);
         case GGML_OP_PAD:
-            return ggml_webgpu_pad(ctx, src0, node);
+            return ggml_webgpu_pad(ctx, encoder, src0, node);
         case GGML_OP_ARGMAX:
-            return ggml_webgpu_argmax(ctx, src0, node);
+            return ggml_webgpu_argmax(ctx, encoder, src0, node);
         case GGML_OP_ARGSORT:
         case GGML_OP_TOP_K:
             // we reuse the same argsort implementation for top_k
-            return ggml_webgpu_argsort(ctx, src0, node);
+            return ggml_webgpu_argsort(ctx, encoder, src0, node);
         case GGML_OP_CUMSUM:
-            return ggml_webgpu_cumsum(ctx, src0, node);
+            return ggml_webgpu_cumsum(ctx, encoder, src0, node);
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
-            return ggml_webgpu_sum_rows(ctx, src0, node);
+            return ggml_webgpu_sum_rows(ctx, encoder, src0, node);
         default:
             return std::nullopt;
     }
@@ -2771,30 +2708,42 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
 
     WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
 
-    std::vector<webgpu_command>    commands;
-    std::vector<webgpu_submission> subs;
-    uint32_t                       num_batched_kernels = 0;
-    bool                           contains_set_rows   = false;
+    std::vector<webgpu_encoded_op> commands;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    std::vector<wgpu::FutureWaitInfo> profile_futures;
+#endif
+    uint32_t             num_batched_kernels = 0;
+    bool                 contains_set_rows   = false;
+    wgpu::CommandEncoder batch_encoder       = ctx->global_ctx->device.CreateCommandEncoder();
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
             contains_set_rows = true;
         }
-        if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
+        if (auto cmd = ggml_webgpu_encode_node(ctx, batch_encoder, cgraph->nodes[i])) {
             commands.push_back(*cmd);
             num_batched_kernels += cmd.value().num_kernels;
         }
 
         if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
-            num_batched_kernels = 0;
-            subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
-            // Process events and check for completed submissions
-            ctx->global_ctx->instance.ProcessEvents();
-            ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
+            num_batched_kernels                = 0;
+            wgpu::CommandBuffer batch_commands = batch_encoder.Finish();
+            ctx->global_ctx->queue.Submit(1, &batch_commands);
+#ifdef GGML_WEBGPU_GPU_PROFILE
+            ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures);
+#endif
+            ctx->param_arena.reset();
             commands.clear();
+            batch_encoder = ctx->global_ctx->device.CreateCommandEncoder();
         }
     }
     if (!commands.empty()) {
-        subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
+        wgpu::CommandBuffer batch_commands = batch_encoder.Finish();
+        ctx->global_ctx->queue.Submit(1, &batch_commands);
+#ifdef GGML_WEBGPU_GPU_PROFILE
+        ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures);
+#endif
+        ctx->param_arena.reset();
         commands.clear();
     }
 
@@ -2805,6 +2754,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
                                    ctx->set_rows_host_error_buf.GetSize());
         wgpu::CommandBuffer set_rows_commands = encoder.Finish();
         ctx->global_ctx->queue.Submit(1, &set_rows_commands);
+    }
+
+    ggml_backend_webgpu_wait_queue(ctx->global_ctx);
+
+    if (contains_set_rows) {
         ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
                                        ctx->set_rows_host_error_buf.GetSize());
         const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
@@ -2814,7 +2768,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
         ctx->set_rows_host_error_buf.Unmap();
     }
 
-    ggml_backend_webgpu_wait(ctx->global_ctx, subs);
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    ggml_backend_webgpu_wait_profile_futures(ctx->global_ctx, profile_futures);
+#endif
     WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
     return GGML_STATUS_SUCCESS;
 }
@@ -3063,18 +3019,16 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                                            (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
                     const bool kv_vec_type_supported =
                         K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
-                    const bool use_vec =
-                        (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
-                        (V->type == K->type);
+                    const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) &&
+                                         kv_vec_type_supported && (V->type == K->type);
                     if (use_vec) {
                         const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
                         const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
                         const size_t   limit_bytes =
                             ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
-                        const size_t q_tile = sg_mat_m;
-                        const size_t base_q_bytes =
-                            (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
-                            2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+                        const size_t q_tile       = sg_mat_m;
+                        const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+                                                    2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
                         size_t bytes_per_kv = 0;
                         if (!kv_direct) {
                             bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
@@ -3084,10 +3038,9 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                         }
                         bytes_per_kv += q_tile;
                         bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
-                        uint32_t kv_tile =
-                            ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
-                        kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
-                        kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
+                        uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
+                        kv_tile          = std::max(sg_mat_n, std::min(32u, kv_tile));
+                        kv_tile          = (kv_tile / sg_mat_n) * sg_mat_n;
                         if (kv_direct) {
                             GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
                             while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
@@ -3097,30 +3050,30 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
 
                         const uint32_t vec_nwg_cap = std::max(
                             1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
-                        uint32_t       nwg         = 1u;
-                        const uint64_t kv_span     = (uint64_t) std::max(1u, kv_tile);
+                        uint32_t       nwg     = 1u;
+                        const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
                         while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
                             nwg <<= 1;
                         }
                         nwg = std::min(nwg, vec_nwg_cap);
 
-                        const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+                        const size_t align =
+                            ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
                         const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
                         if (nwg > 1u) {
                             const uint64_t tmp_data_elems  = nrows * (uint64_t) V->ne[0] * nwg;
                             const uint64_t tmp_stats_elems = nrows * 2u * nwg;
-                            const size_t tmp_size_bytes = ROUNDUP_POW2(
+                            const size_t   tmp_size_bytes  = ROUNDUP_POW2(
                                 (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
                             res += tmp_size_bytes + align;
                         }
                         if (mask != nullptr) {
-                            const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
-                            const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
-                            const uint32_t stride_mask3 =
-                                (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
+                            const uint32_t blk_nblk0       = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
+                            const uint32_t blk_nblk1       = CEIL_DIV((uint32_t) Q->ne[1], 1u);
+                            const uint32_t stride_mask3    = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
                             const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
-                            const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
-                            const size_t blk_size_bytes =
+                            const uint64_t blk_elems       = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
+                            const size_t   blk_size_bytes =
                                 ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
                             res += blk_size_bytes + align;
                         }
@@ -3195,11 +3148,11 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
     ctx->capabilities.memset_bytes_per_thread =
         CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
     std::vector<wgpu::ConstantEntry> constants(2);
-    constants[0].key         = "wg_size";
-    constants[0].value       = WEBGPU_MAX_WG_SIZE;
-    constants[1].key         = "bytes_per_thread";
-    constants[1].value       = ctx->capabilities.memset_bytes_per_thread;
-    ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
+    constants[0].key     = "wg_size";
+    constants[0].value   = WEBGPU_MAX_WG_SIZE;
+    constants[1].key     = "bytes_per_thread";
+    constants[1].value   = ctx->capabilities.memset_bytes_per_thread;
+    ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
 }
 
 static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
@@ -3331,9 +3284,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
 
     ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
-    ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                                 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                                                 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+    ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, ctx->webgpu_global_ctx->memset_params_buf,
+                              WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                              "memset_params_buf");
     ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
@@ -3357,9 +3310,8 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
     webgpu_context                       webgpu_ctx = std::make_shared<webgpu_context_struct>();
     webgpu_ctx->global_ctx                          = dev_ctx->webgpu_global_ctx;
     webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
-    webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                                    wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                                    wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
+    webgpu_ctx->param_arena.init(webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, WEBGPU_NUM_PARAM_SLOTS,
+                                 webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment);
     ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
                               WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
                               wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");