]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml webgpu: fix workgroup dispatch limit for large batch sizes (#19965)
authorAbhijit Ramesh <redacted>
Tue, 3 Mar 2026 03:35:11 +0000 (19:35 -0800)
committerGitHub <redacted>
Tue, 3 Mar 2026 03:35:11 +0000 (19:35 -0800)
* ggml-webgpu: fix workgroup dispatch limit for large batch sizes

WebGPU limits workgroup sizes to 65535 per dimension. Large MUL_MAT
operations with batch sizes exceedeing this limi would fail.

* add compute_2d_workgroups() helper to split total workgroup ID across
X/Y dimensions

* update mul_mat_reg_tile.wgsl to reconstruct linear workgroup ID from 2D
   dispatch

* update mul_mat_subgroup_matrix.wgsl to reconstruct linear workgroup ID
  from 2D dispatch

* update mul_mat.wgsl to compute global index from 2D workgroup
  coordinates

* refactor all three mul_mat dispatch paths to use the shared helper

* ggml-webgpu: add bounds checking for over-dispatched workgroups

2D workgroup dispatch can over-dispatch when total workgroups don't
divide evenly into the 65535 per-dimension limit. Extra workgroups
would compute invalid batch indices, causing memory corruption.

* add batch_idx bound check to mul_mat_reg_tile.wgsl and
mul_mat_subgroup_matrix.wgsl to prevent over-dispatched workgroups
from accessing invalid memory

* fixes test failures with large batch sizes (eg., bs=[128, 1024])

* ggml-webgpu: add back TODO for spliting large sizes into batches

* Optimize 2d workgroup provisioning

* Set some parameters that increase speed

---------

Co-authored-by: Reese Levine <redacted>
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl

index 913cf7f8825ebe47d526213e0ff3319ae437d44f..19451618ec5c8d1910fae50b2fa8f1ed326f84e8 100644 (file)
 #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
 #define CEIL_DIV(M, N)        (((M) + (N) - 1) / (N))
 
+// Return a rectangular grid of workgroups with minimal over-provisioned workgroups.
+// Assumes that the total number of workgroups does not exceed max_per_dim^2.
+static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
+    wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));
+    wg_x = CEIL_DIV(total_wg, wg_y);
+}
+
 #ifdef GGML_WEBGPU_DEBUG
 #    define WEBGPU_LOG_DEBUG(msg)  std::cout << msg << std::endl
 #    define WEBGPU_DEBUG_BUF_ELEMS 512
@@ -69,8 +76,8 @@
 
 /* Constants */
 
-#define WEBGPU_NUM_PARAM_BUFS                16u
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     8u
+#define WEBGPU_NUM_PARAM_BUFS                48u
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     16u
 #define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
 // Maximum number of in-flight submissions per-thread, to avoid exhausting the
 // parameter buffer pool
@@ -1146,8 +1153,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     };
 
     // Calculate workgroup dimensions
-    uint32_t wg_x = 1;
-    uint32_t wg_y = 1;
+    uint32_t       wg_x = 1;
+    uint32_t       wg_y = 1;
+    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
 
     if (use_fast && is_vec) {
         auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
@@ -1155,9 +1163,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
         uint32_t batches       = dst->ne[2] * dst->ne[3];
         uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
         uint32_t total_wg      = output_groups * batches;
-        // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
-        wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
-        wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
+        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
     } else if (use_fast) {
         auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
 
@@ -1176,12 +1182,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
             wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
             wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
         }
-        wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+        uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
+
     } else {  // legacy
         auto     decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
         uint32_t wg_size   = decisions->wg_size;
-        wg_x               = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
-        wg_y               = 1;
+        uint32_t total_wg  = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
+        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
     }
 
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
index 6aba47317c6d27ce477da894c22ff32bd8fdc2ad..5b9f5b362249706ba54bf640d221d9ba61207f06 100644 (file)
@@ -679,19 +679,24 @@ struct MulMatParams {
 @group(0) @binding(3) var<uniform> params: MulMatParams;
 
 @compute @workgroup_size(256)
-fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>) {
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+    let global_idx = wg_linear * 256u + local_id.x;
+
     let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
-    if (global_id.x >= total) {
+    if (global_idx >= total) {
         return;
     }
 
     let dst2_stride = params.m * params.n;
     let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
 
-    let dst3_idx = global_id.x / dst3_stride;
+    let dst3_idx = global_idx / dst3_stride;
     let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
     let src13_idx = dst3_idx; // src1 is not broadcast
-    let dst3_rem = global_id.x % dst3_stride;
+    let dst3_rem = global_idx % dst3_stride;
 
     let dst2_idx = dst3_rem / dst2_stride;
     let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
index 771e5cd1ee388fc3f45bccf870a0e443c80e9782..761e3017c14cbfdf66b3b698a150f61ce9a3501a 100644 (file)
@@ -54,7 +54,8 @@ var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
 
 @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
 fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
-        @builtin(local_invocation_id) local_id: vec3<u32>) {
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>) {
 
     let thread_id = local_id.x;
     let local_m = get_local_m(thread_id);
@@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
     let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
     let wg_per_matrix = wg_m_count * wg_n_count;
 
-    let batch_idx = wg_id.x / wg_per_matrix;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
 
-    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let batch_idx = wg_linear / wg_per_matrix;
+
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    let wg_in_batch = wg_linear % wg_per_matrix;
     let wg_m = wg_in_batch % wg_m_count;
     let wg_n = wg_in_batch / wg_m_count;
 
index 64529e03cdc065b78e46b5636bd69f427ac30cb6..9f9ef279f29686f9def009a0c9b56ace18996a08 100644 (file)
@@ -69,7 +69,8 @@ var<workgroup> shmem: array<f16, SHMEM_SIZE>;
 @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
 fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         @builtin(local_invocation_id) local_id: vec3<u32>,
-        @builtin(subgroup_id) subgroup_id: u32) {
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(num_workgroups) num_wg: vec3<u32>) {
 
     let thread_id = local_id.x;
     let subgroup_m = subgroup_id % SUBGROUP_M;
@@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
     let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
     let wg_per_matrix = wg_m_count * wg_n_count;
 
-    let batch_idx = wg_id.x / wg_per_matrix;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
 
-    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let batch_idx = wg_linear / wg_per_matrix;
+
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    let wg_in_batch = wg_linear % wg_per_matrix;
     let wg_m = wg_in_batch % wg_m_count;
     let wg_n = wg_in_batch / wg_m_count;