// Round x up to the next multiple of pow2. Uses the bit-mask trick, so pow2
// MUST be a power of two — it is wrong for arbitrary alignments.
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
// Integer ceiling division: smallest integer >= M/N for positive M, N.
// Note: M + N - 1 must not overflow the operand type.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
// Factor a 1-D workgroup count into a rectangular (wg_x, wg_y) dispatch grid
// so that each dimension stays within the device limit, while keeping the
// number of over-provisioned (wasted) workgroups minimal.
// Precondition: total_wg must not exceed max_per_dim * max_per_dim.
//
// total_wg    - total number of workgroups the kernel logically needs
// max_per_dim - device limit on workgroups per dispatch dimension
// wg_x, wg_y  - out: grid dimensions, with wg_x * wg_y >= total_wg
static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
    // Number of rows needed if each row carries up to max_per_dim workgroups;
    // clamp to one row so an empty dispatch still yields a valid y dimension.
    const uint32_t rows = (total_wg + max_per_dim - 1) / max_per_dim;
    wg_y = std::max(1u, rows);
    // Spread the total evenly across the rows (ceiling), so only the tail of
    // the last row is wasted. Since wg_y >= ceil(total/max), wg_x <= max_per_dim.
    wg_x = (total_wg + wg_y - 1) / wg_y;
}
+
#ifdef GGML_WEBGPU_DEBUG
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
# define WEBGPU_DEBUG_BUF_ELEMS 512
/* Constants */
-#define WEBGPU_NUM_PARAM_BUFS 16u
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
+#define WEBGPU_NUM_PARAM_BUFS 48u
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
};
// Calculate workgroup dimensions
- uint32_t wg_x = 1;
- uint32_t wg_y = 1;
+ uint32_t wg_x = 1;
+ uint32_t wg_y = 1;
+ const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
- // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
- wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
- wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else if (use_fast) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
}
- wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+ uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
+
} else { // legacy
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
- wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
- wg_y = 1;
+ uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
+ compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
@group(0) @binding(3) var<uniform> params: MulMatParams;
@compute @workgroup_size(256)
-fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
+ @builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(num_workgroups) num_wg: vec3<u32>) {
+ let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+ let global_idx = wg_linear * 256u + local_id.x;
+
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
- if (global_id.x >= total) {
+ if (global_idx >= total) {
return;
}
let dst2_stride = params.m * params.n;
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
- let dst3_idx = global_id.x / dst3_stride;
+ let dst3_idx = global_idx / dst3_stride;
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
let src13_idx = dst3_idx; // src1 is not broadcast
- let dst3_rem = global_id.x % dst3_stride;
+ let dst3_rem = global_idx % dst3_stride;
let dst2_idx = dst3_rem / dst2_stride;
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
- @builtin(local_invocation_id) local_id: vec3<u32>) {
+ @builtin(local_invocation_id) local_id: vec3<u32>,
+ @builtin(num_workgroups) num_wg: vec3<u32>) {
let thread_id = local_id.x;
let local_m = get_local_m(thread_id);
let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
let wg_per_matrix = wg_m_count * wg_n_count;
- let batch_idx = wg_id.x / wg_per_matrix;
+ let wg_linear = wg_id.y * num_wg.x + wg_id.x;
- let wg_in_batch = wg_id.x % wg_per_matrix;
+ let batch_idx = wg_linear / wg_per_matrix;
+
+ let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+ if (batch_idx >= total_batches) {
+ return;
+ }
+
+ let wg_in_batch = wg_linear % wg_per_matrix;
let wg_m = wg_in_batch % wg_m_count;
let wg_n = wg_in_batch / wg_m_count;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
- @builtin(subgroup_id) subgroup_id: u32) {
+ @builtin(subgroup_id) subgroup_id: u32,
+ @builtin(num_workgroups) num_wg: vec3<u32>) {
let thread_id = local_id.x;
let subgroup_m = subgroup_id % SUBGROUP_M;
let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
let wg_per_matrix = wg_m_count * wg_n_count;
- let batch_idx = wg_id.x / wg_per_matrix;
+ let wg_linear = wg_id.y * num_wg.x + wg_id.x;
- let wg_in_batch = wg_id.x % wg_per_matrix;
+ let batch_idx = wg_linear / wg_per_matrix;
+
+ let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+ if (batch_idx >= total_batches) {
+ return;
+ }
+
+ let wg_in_batch = wg_linear % wg_per_matrix;
let wg_m = wg_in_batch % wg_m_count;
let wg_n = wg_in_batch / wg_m_count;