* Update RMS_NORM to use the shader preprocessor and add L2_NORM.
* Rename rms_norm to row_norm.
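Both ops share the same per-row reduction, so they now compile from a single row_norm WGSL template specialized via the OP_RMS_NORM / OP_L2_NORM and INPLACE defines.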
}
};
+/** Row Norm **/
+
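+// Cache key for row-norm pipelines: the op (RMS_NORM or L2_NORM) plus
+// whether the op runs in place over the src buffer.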
+struct ggml_webgpu_row_norm_pipeline_key {
+ ggml_op op;
+ bool inplace;
+
+ bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
+ return op == other.op && inplace == other.inplace;
+ }
+};
+
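+// Hash functor so the key can be used in an unordered_map.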
+struct ggml_webgpu_row_norm_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.op);
+ ggml_webgpu_hash_combine(seed, key.inplace);
+ return seed;
+ }
+};
+
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
+ std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
+ row_norm_pipelines; // op/inplace
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
return sum_rows_pipelines[1];
}
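+ // Returns the cached row-norm pipeline for this op/inplace combination,
+ // compiling and caching it on first use.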
+ webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_row_norm_pipeline_key key = {
+ .op = context.dst->op,
+ .inplace = context.inplace,
+ };
+
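+ // Return the cached pipeline if this variant was already compiled.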
+ auto it = row_norm_pipelines.find(key);
+ if (it != row_norm_pipelines.end()) {
+ return it->second;
+ }
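+ // Cache miss: choose the defines and variant name for this op.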
+ std::vector<std::string> defines;
+ std::string variant;
+
+ switch (key.op) {
+ case GGML_OP_RMS_NORM:
+ defines.push_back("OP_RMS_NORM");
+ variant = "rms_norm";
+ break;
+ case GGML_OP_L2_NORM:
+ defines.push_back("OP_L2_NORM");
+ variant = "l2_norm";
+ break;
+ default:
+ GGML_ABORT("Unsupported op for row_norm shader");
+ }
+
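+ // In-place variants write back into src and omit the separate dst binding.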
+ if (key.inplace) {
+ defines.push_back("INPLACE");
+ variant += "_inplace";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
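+ // Specialize the shared WGSL template with the collected defines.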
+ auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
+ row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
+ return row_norm_pipelines[key];
+ }
+
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool vec4 = context.src0->ne[0] % 4 == 0;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
- std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
-static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
- int inplace = ggml_webgpu_tensor_equal(src, dst);
+static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
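+ // The op is in-place when src and dst refer to the same tensor buffer.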
+ bool inplace = ggml_webgpu_tensor_equal(src, dst);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
- return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
- entries, ggml_nrows(src));
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src,
+ .dst = dst,
+ .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
+ .inplace = inplace,
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
}
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
- return ggml_webgpu_rms_norm(ctx, src0, node);
+ case GGML_OP_L2_NORM:
+ return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
case GGML_OP_GLU:
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
-static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
-
- webgpu_ctx->rms_norm_pipelines[0] =
- ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
- webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
- webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
-}
-
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
- ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
break;
}
case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_ROPE:
--- /dev/null
+#ifdef INPLACE
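+// In-place variant: write the scaled value back into the src buffer; no dst binding.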
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+ src[dst_offset] = scale * src[src_offset];
+}
+
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
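+// Out-of-place variant: results go to a separate dst buffer at binding 1.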
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+ dst[dst_offset] = scale * src[src_offset];
+}
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#endif
+
+struct Params {
+ offset_src: u32, // in elements
+ offset_dst: u32, // in elements
+
+ // Strides (in elements)
+ stride_src1: u32,
+ stride_src2: u32,
+ stride_src3: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // Shape of src/dst
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+ ne3: u32,
+
+ eps: f32
+};
+
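+// src is declared read_write so the same template serves the in-place variant.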
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
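+// Workgroup-shared scratch for reducing per-thread partial sums.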
+var<workgroup> scratch: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+
+ // one workgroup per row; each thread strides across the row's columns
+ var i = wid.x;
+ let i3 = i / (params.ne2 * params.ne1);
+ i = i % (params.ne2 * params.ne1);
+ let i2 = i / params.ne1;
+ let i1 = i % params.ne1;
+ let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
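+ // Number of strided iterations each thread needs to cover all ne0 columns.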
+ let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+
+ var sum = 0.0f;
+ var col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ // Square via multiplication: pow() is indeterminate for negative bases in WGSL.
+ let v = src[i_src_row + col];
+ sum += v * v;
+ col += WG_SIZE;
+ }
+
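+ // Tree reduction in shared memory: halve the active threads each step.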
+ scratch[lid.x] = sum;
+ workgroupBarrier();
+ var offset: u32 = WG_SIZE / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] += scratch[lid.x + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ sum = scratch[0];
+
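+ // RMS norm scales by 1/rms(row); L2 norm scales by 1/max(||row||, eps).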
+#ifdef OP_RMS_NORM
+ let scale = 1.0 / sqrt(sum / f32(params.ne0) + params.eps);
+#elif OP_L2_NORM
+ let scale = 1.0 / max(sqrt(sum), params.eps);
+#endif
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ update(i_src_row + col, i_dst_row + col, scale);
+ col += WG_SIZE;
+ }
+}