]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-webgpu: Update the `RMS_NORM` preprocessor and add `L2_NORM` (llama/20665)
authorMasashi Yoshimura <redacted>
Sat, 28 Mar 2026 09:47:59 +0000 (11:47 +0200)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
* Update the preprocessor of RMS_NORM and add L2_NORM.

* Fix the name of rms_norm to row_norm.

src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
src/ggml-webgpu/ggml-webgpu.cpp
src/ggml-webgpu/wgsl-shaders/row_norm.wgsl [new file with mode: 0644]

index ad665e4de93c157d41dbf5099afd9254b2978c5f..9d16abf20d7ab727d12f8f2cfcdcee926ea1cd95 100644 (file)
@@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
     }
 };
 
+/** Row Norm **/
+
+struct ggml_webgpu_row_norm_pipeline_key {
+    ggml_op op;
+    bool    inplace;
+
+    bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
+        return op == other.op && inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_row_norm_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
 /** Pad **/
 struct ggml_webgpu_pad_pipeline_key {
     bool circular;
@@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib {
     std::unordered_map<int, webgpu_pipeline> argsort_pipelines;        // key is order
     std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines;  // key is order
     std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;         // key is fixed, no variants yet
+    std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
+        row_norm_pipelines;                                            // op/inplace
     std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
         get_rows_pipelines;                                            // src_type, vectorized
     std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
@@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib {
         return sum_rows_pipelines[1];
     }
 
+    webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_row_norm_pipeline_key key = {
+            .op      = context.dst->op,
+            .inplace = context.inplace,
+        };
+
+        auto it = row_norm_pipelines.find(key);
+        if (it != row_norm_pipelines.end()) {
+            return it->second;
+        }
+        std::vector<std::string> defines;
+        std::string              variant;
+
+        switch (key.op) {
+            case GGML_OP_RMS_NORM:
+                defines.push_back("OP_RMS_NORM");
+                variant = "rms_norm";
+                break;
+            case GGML_OP_L2_NORM:
+                defines.push_back("OP_L2_NORM");
+                variant = "l2_norm";
+                break;
+            default:
+                GGML_ABORT("Unsupported op for row_norm shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed          = preprocessor.preprocess(wgsl_row_norm, defines);
+        row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
+        return row_norm_pipelines[key];
+    }
+
     webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
         bool vec4 = context.src0->ne[0] % 4 == 0;
 
index 4b0eeac0f42115981e5708eac751da557a0e89eb..f7973df682abf3921741a3d1660fe6b0f6523292 100644 (file)
@@ -366,7 +366,6 @@ struct webgpu_context_struct {
 
     std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type
 
-    std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;       // glu_op, type, split
 
@@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    int inplace = ggml_webgpu_tensor_equal(src, dst);
+static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
-                                     entries, ggml_nrows(src));
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .dst         = dst,
+        .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
+        .inplace     = inplace,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
 }
 
 static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@@ -2192,7 +2198,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_REPEAT:
             return ggml_webgpu_repeat(ctx, src0, node);
         case GGML_OP_RMS_NORM:
-            return ggml_webgpu_rms_norm(ctx, src0, node);
+        case GGML_OP_L2_NORM:
+            return ggml_webgpu_row_norm(ctx, src0, node);
         case GGML_OP_ROPE:
             return ggml_webgpu_rope(ctx, src0, src1, src2, node);
         case GGML_OP_GLU:
@@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
         ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
-static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
-
-    webgpu_ctx->rms_norm_pipelines[0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
-    webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
-}
-
 static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
@@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
                               wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
 
     ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
-    ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
     ggml_webgpu_init_rope_pipeline(webgpu_ctx);
     ggml_webgpu_init_glu_pipeline(webgpu_ctx);
     ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
@@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                 break;
             }
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
             break;
         case GGML_OP_ROPE:
diff --git a/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl
new file mode 100644 (file)
index 0000000..7777944
--- /dev/null
@@ -0,0 +1,97 @@
+#ifdef INPLACE
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+    src[dst_offset] = scale * src[src_offset];
+}
+
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {
+    dst[dst_offset] = scale * src[src_offset];
+}
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#endif
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Shape of src/dst
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+    ne3: u32,
+
+    eps: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+var<workgroup> scratch: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    // one thread per row
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+
+    var sum = 0.0f;
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        sum += pow(src[i_src_row + col], 2.0);
+        col += WG_SIZE;
+    }
+
+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    var offset: u32 = WG_SIZE / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    sum = scratch[0];
+
+#ifdef OP_RMS_NORM
+    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
+#elif OP_L2_NORM
+    let scale = 1.0/max(sqrt(sum), params.eps);
+#endif
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_src_row + col, i_dst_row + col, scale);
+        col += WG_SIZE;
+    }
+}