struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;
- std::mutex mutex;
-
+ // The pool must be synchronized because
+ // 1. The memset pool is shared globally by every ggml buffer,
+ // since allocating a pool per ggml buffer would consume too much memory.
+ // 2. For the per-thread buffer pools in webgpu_context,
+ // buffers are allocated and freed in Dawn callbacks,
+ // which can run on a different thread than the calling thread.
+ std::mutex mutex;
std::condition_variable cv;
void init(wgpu::Device device,
#endif
};
-struct webgpu_capabilities_base {
+struct webgpu_capabilities {
wgpu::Limits limits;
bool supports_subgroup_matrix = false;
wgpu::Device device;
wgpu::Queue queue;
- webgpu_capabilities_base capabilities;
+ webgpu_capabilities capabilities;
// Shared buffer to move data from device to host
- wgpu::Buffer get_tensor_staging_buf;
+ wgpu::Buffer get_tensor_staging_buf;
// Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches.
- std::recursive_mutex mutex;
+ std::recursive_mutex mutex;
webgpu_buf_pool memset_buf_pool;
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
size_t memset_bytes_per_thread;
-
};
typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
// Per-thread data required to actually run WebGPU operations in a backend instance
struct ggml_backend_webgpu_context {
- webgpu_context webgpu_ctx;
- std::once_flag init_once;
- std::string name;
+ webgpu_context webgpu_ctx;
+ std::string name;
};
// Per-thread data related to buffers
};
webgpu_pipeline pipeline;
- {
- // TODO: remove guard once pipeline caches are per-thread
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- auto it = ctx->pad_pipelines.find(pipeline_key);
- if (it != ctx->pad_pipelines.end()) {
- pipeline = it->second;
- } else {
- ggml_webgpu_processed_shader processed =
- ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
- pipeline =
- ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- pipeline.context = processed.decisions;
- ctx->pad_pipelines.emplace(pipeline_key, pipeline);
- }
+ auto it = ctx->pad_pipelines.find(pipeline_key);
+ if (it != ctx->pad_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->pad_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
};
webgpu_pipeline pipeline;
- // TODO: remove guard once pipeline caches are per-thread
- {
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- auto it = ctx->set_rows_pipelines.find(key);
- if (it != ctx->set_rows_pipelines.end()) {
- pipeline = it->second;
- } else {
- ggml_webgpu_processed_shader processed =
- ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
- pipeline =
- ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- pipeline.context = processed.decisions;
- ctx->set_rows_pipelines.emplace(key, pipeline);
- }
+ auto it = ctx->set_rows_pipelines.find(key);
+ if (it != ctx->set_rows_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->set_rows_pipelines.emplace(key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
};
webgpu_pipeline pipeline;
- // TODO: remove guard once pipeline caches are per-thread
- {
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- auto it = ctx->flash_attn_pipelines.find(key);
- if (it != ctx->flash_attn_pipelines.end()) {
- pipeline = it->second;
- } else {
- ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
- .key = key,
- .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
- .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
- .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
- .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
- .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
- };
-
- ggml_webgpu_processed_shader processed =
- ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
- pipeline =
- ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- pipeline.context = processed.decisions;
- ctx->flash_attn_pipelines.emplace(key, pipeline);
- }
+ auto it = ctx->flash_attn_pipelines.find(key);
+ if (it != ctx->flash_attn_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
+ .key = key,
+ .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
+ .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
+ .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+ .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
+ };
+
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->flash_attn_pipelines.emplace(key, pipeline);
}
ggml_webgpu_flash_attn_shader_decisions decisions =
};
webgpu_pipeline pipeline;
- {
- // TODO: remove guard once pipeline caches are per-thread
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- auto it = ctx->unary_pipelines.find(pipeline_key);
- if (it != ctx->unary_pipelines.end()) {
- pipeline = it->second;
- } else {
- ggml_webgpu_processed_shader processed =
- ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
- pipeline =
- ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- pipeline.context = processed.decisions;
- ctx->unary_pipelines.emplace(pipeline_key, pipeline);
- }
+ auto it = ctx->unary_pipelines.find(pipeline_key);
+ if (it != ctx->unary_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->unary_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
};
webgpu_pipeline pipeline;
- {
- // TODO: remove guard once pipeline caches are per-thread
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
- if (it != ctx->argmax_pipelines.end()) {
- pipeline = it->second;
- } else {
- ggml_webgpu_processed_shader processed =
- ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
- pipeline =
- ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
- }
+ auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
+ if (it != ctx->argmax_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
}
uint32_t wg_x = ggml_nelements(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
.order = order
};
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- webgpu_pipeline argsort_pipeline;
- auto it = ctx->argsort_pipelines.find(order);
+ webgpu_pipeline argsort_pipeline;
+ auto it = ctx->argsort_pipelines.find(order);
if (it != ctx->argsort_pipelines.end()) {
argsort_pipeline = it->second;
} else {
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline;
- // TODO: remove guard once pipeline caches are per-thread
- {
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- auto it = ctx->cumsum_pipelines.find(1);
- if (it != ctx->cumsum_pipelines.end()) {
- pipeline = it->second;
- } else {
- ggml_webgpu_processed_shader processed =
- ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
- pipeline =
- ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- ctx->cumsum_pipelines.emplace(1, pipeline);
- }
+ auto it = ctx->cumsum_pipelines.find(1);
+ if (it != ctx->cumsum_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->cumsum_pipelines.emplace(1, pipeline);
}
uint32_t wg_x = ggml_nrows(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
};
webgpu_pipeline pipeline;
- {
- // TODO: remove guard once pipeline caches are per-thread
- std::lock_guard<std::recursive_mutex> lock(ctx->global_ctx->mutex);
- auto it = ctx->sum_rows_pipelines.find(1);
- if (it != ctx->sum_rows_pipelines.end()) {
- pipeline = it->second;
- } else {
- ggml_webgpu_processed_shader processed =
- ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
- pipeline =
- ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
- ctx->sum_rows_pipelines.emplace(1, pipeline);
- }
+ auto it = ctx->sum_rows_pipelines.find(1);
+ if (it != ctx->sum_rows_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ ctx->sum_rows_pipelines.emplace(1, pipeline);
}
uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
#ifdef GGML_WEBGPU_GPU_PROFILE
// Initialize buffer pool for timestamp queries, used for profiling
- ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
- WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
- wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
- wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
+ ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
+ ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
+ wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
+ wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
#endif
GGML_LOG_INFO(