std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
- std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
+ std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
- std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
- std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
- std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
- std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
- std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
+ std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
+
+ std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
+ binary_pipelines;
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}
+// Used to determine if two tensors share the same buffer and their byte ranges overlap,
+static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
+ return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+ ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
+ ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
+}
+
+struct binary_overlap_flags {
+ bool inplace; // src0 == dst
+ bool overlap; // src1 == dst
+};
+
+static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ binary_overlap_flags flags = {};
+ flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
+ flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
+
+ return flags;
+}
+
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
-static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
- ggml_tensor * src0,
- ggml_tensor * src1,
- ggml_tensor * dst,
- webgpu_pipeline & pipeline,
- bool inplace) {
+static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
+
+ ggml_webgpu_binary_pipeline_key pipeline_key = {
+ .type = dst->type,
+ .op = dst->op,
+ .inplace = flags.inplace,
+ .overlap = flags.overlap,
+ };
+ ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
+ .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+ };
+
+ webgpu_pipeline pipeline;
+ auto it = ctx->binary_pipelines.find(pipeline_key);
+ if (it != ctx->binary_pipelines.end()) {
+ pipeline = it->second;
+ } else {
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
+ pipeline =
+ ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+ pipeline.context = processed.decisions;
+ ctx->binary_pipelines.emplace(pipeline_key, pipeline);
+ }
+
+ ggml_webgpu_generic_shader_decisions decisions =
+ *static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
+
+ uint32_t ne = (uint32_t) ggml_nelements(dst);
+
std::vector<uint32_t> params = {
- (uint32_t) ggml_nelements(dst),
+ ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) src1->ne[3],
};
- std::vector<wgpu::BindGroupEntry> entries = {
- { .binding = 0,
- .buffer = ggml_webgpu_tensor_buf(src0),
- .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
- .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
- { .binding = 1,
- .buffer = ggml_webgpu_tensor_buf(src1),
- .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
- .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
- };
- if (!inplace) {
+ std::vector<wgpu::BindGroupEntry> entries;
+
+ entries.push_back({
+ .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0),
+ });
+
+ entries.push_back({
+ .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1),
+ });
+
+ if (!flags.inplace && !flags.overlap) {
entries.push_back({ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
- uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+ uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
return std::nullopt;
#endif
case GGML_OP_ADD:
- {
- int inplace = ggml_webgpu_tensor_equal(src0, node);
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
- }
case GGML_OP_SUB:
- {
- int inplace = ggml_webgpu_tensor_equal(src0, node);
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
- }
case GGML_OP_MUL:
- {
- int inplace = ggml_webgpu_tensor_equal(src0, node);
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
- }
case GGML_OP_DIV:
- {
- int inplace = ggml_webgpu_tensor_equal(src0, node);
- return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
- }
+ return ggml_webgpu_binary_op(ctx, src0, src1, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
-static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants);
- webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants);
- webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
- webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants);
- webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants);
- webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
- webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants);
- webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants);
- webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
- webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
- webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants);
- webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants);
- webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
- webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
-}
-
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
- ggml_webgpu_init_add_pipeline(webgpu_ctx);
- ggml_webgpu_init_sub_pipeline(webgpu_ctx);
- ggml_webgpu_init_mul_pipeline(webgpu_ctx);
- ggml_webgpu_init_div_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
+++ /dev/null
-#define(VARIANTS)
-
-[
- {
- "SHADER_NAME": "add_f32",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "+"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "add_f16",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "+"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "add_f32_inplace",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "+"
- },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "add_f16_inplace",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "+"
- },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "mul_f32",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "*"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "mul_f16",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "*"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "mul_f32_inplace",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "*"
- },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "mul_f16_inplace",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "*"
- },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "sub_f32",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "-"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "sub_f16",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "-"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "sub_f32_inplace",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "-"
- },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "sub_f16_inplace",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "-"
- },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "div_f32",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "/"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "div_f16",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "/"
- },
- "DECLS": ["NOT_INPLACE"]
- },
- {
- "SHADER_NAME": "div_f32_inplace",
- "REPLS": {
- "TYPE" : "f32",
- "OP": "/"
- },
- "DECLS": ["INPLACE"]
- },
- {
- "SHADER_NAME": "div_f16_inplace",
- "REPLS": {
- "TYPE" : "f16",
- "OP": "/"
- },
- "DECLS": ["INPLACE"]
- }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(NOT_INPLACE)
-
-fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
- dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
-}
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-#enddecl(NOT_INPLACE)
-
-#decl(INPLACE)
-
-fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
- src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
-}
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(INPLACE)
-
-#end(DECLS)
-
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-DECLS
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
- if (gid.x < params.ne) {
- update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
- }
-}
-
-#end(SHADER)