git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-webgpu: Update the `RMS_NORM` preprocessor and add `L2_NORM` (#20665)
author Masashi Yoshimura <redacted>
Thu, 19 Mar 2026 04:08:59 +0000 (13:08 +0900)
committer GitHub <redacted>
Thu, 19 Mar 2026 04:08:59 +0000 (21:08 -0700)
* Update the preprocessor of RMS_NORM and add L2_NORM.

* Fix the name of rms_norm to row_norm.

docs/ops.md
docs/ops/WebGPU.csv
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl [deleted file]
ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl [new file with mode: 0644]

index 47534b1401c92f550df226fdd0e851c2ea9dc156..14b53f2c40510be37b21e538ded0a2a779f1054c 100644 (file)
@@ -62,7 +62,7 @@ Legend:
 |                        HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                           IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 |                        IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
-|                          L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
+|                          L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
 |                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
 |                              LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
 |                             MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
index 56bae2f3c81d7bb30dc82eb46369a88c9e3e6f9c..4b735d45799a44e519a38e85cb48aec922a8ba5e 100644 (file)
 "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000000","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000001","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000100","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.100000","support","0","no","WebGPU"
-"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU"
 "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU"
index ad665e4de93c157d41dbf5099afd9254b2978c5f..9d16abf20d7ab727d12f8f2cfcdcee926ea1cd95 100644 (file)
@@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
     }
 };
 
+/** Row Norm **/
+
+struct ggml_webgpu_row_norm_pipeline_key {  // key of the row_norm_pipelines cache: one pipeline per (op, inplace)
+    ggml_op op;  // taken from context.dst->op; get_row_norm_pipeline handles GGML_OP_RMS_NORM and GGML_OP_L2_NORM
+    bool    inplace;  // taken from context.inplace; selects the INPLACE shader variant
+    // Equality over both fields, required for use as an std::unordered_map key.
+    bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
+        return op == other.op && inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_row_norm_pipeline_key_hash {  // hash functor paired with ggml_webgpu_row_norm_pipeline_key
+    size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.op);  // fold in the same two fields that operator== compares
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
 /** Pad **/
 struct ggml_webgpu_pad_pipeline_key {
     bool circular;
@@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib {
     std::unordered_map<int, webgpu_pipeline> argsort_pipelines;        // key is order
     std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines;  // key is order
     std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;         // key is fixed, no variants yet
+    std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
+        row_norm_pipelines;                                            // op/inplace
     std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
         get_rows_pipelines;                                            // src_type, vectorized
     std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
@@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib {
         return sum_rows_pipelines[1];
     }
 
+    webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {  // cached RMS_NORM / L2_NORM pipeline
+        ggml_webgpu_row_norm_pipeline_key key = {
+            .op      = context.dst->op,
+            .inplace = context.inplace,
+        };
+        // Return the cached pipeline if this (op, inplace) variant has already been built.
+        auto it = row_norm_pipelines.find(key);
+        if (it != row_norm_pipelines.end()) {
+            return it->second;
+        }
+        std::vector<std::string> defines;
+        std::string              variant;
+        // Pick the preprocessor define for the op and the variant name handed to ggml_webgpu_create_pipeline.
+        switch (key.op) {
+            case GGML_OP_RMS_NORM:
+                defines.push_back("OP_RMS_NORM");
+                variant = "rms_norm";
+                break;
+            case GGML_OP_L2_NORM:
+                defines.push_back("OP_L2_NORM");
+                variant = "l2_norm";
+                break;
+            default:
+                GGML_ABORT("Unsupported op for row_norm shader");
+        }
+        // In-place variant: extra INPLACE define and an "_inplace" suffix on the variant name.
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+        // NOTE(review): WG_SIZE is baked into the shader text, but context.max_wg_size is not part of the cache key; presumably every caller passes the same value - confirm.
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+        // Preprocess the row_norm WGSL with the collected defines, build the pipeline and cache it under key.
+        auto processed          = preprocessor.preprocess(wgsl_row_norm, defines);
+        row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
+        return row_norm_pipelines[key];
+    }
+
     webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
         bool vec4 = context.src0->ne[0] % 4 == 0;
 
index 4b0eeac0f42115981e5708eac751da557a0e89eb..f7973df682abf3921741a3d1660fe6b0f6523292 100644 (file)
@@ -366,7 +366,6 @@ struct webgpu_context_struct {
 
     std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type
 
-    std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;       // glu_op, type, split
 
@@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    int inplace = ggml_webgpu_tensor_equal(src, dst);
+static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
-                                     entries, ggml_nrows(src));
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .dst         = dst,
+        .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
+        .inplace     = inplace,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
 }
 
 static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@@ -2192,7 +2198,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_REPEAT:
             return ggml_webgpu_repeat(ctx, src0, node);
         case GGML_OP_RMS_NORM:
-            return ggml_webgpu_rms_norm(ctx, src0, node);
+        case GGML_OP_L2_NORM:
+            return ggml_webgpu_row_norm(ctx, src0, node);
         case GGML_OP_ROPE:
             return ggml_webgpu_rope(ctx, src0, src1, src2, node);
         case GGML_OP_GLU:
@@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
         ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
-static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
-
-    webgpu_ctx->rms_norm_pipelines[0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
-    webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
-}
-
 static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
@@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
                               wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
 
     ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
-    ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
     ggml_webgpu_init_rope_pipeline(webgpu_ctx);
     ggml_webgpu_init_glu_pipeline(webgpu_ctx);
     ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
@@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                 break;
             }
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
             break;
         case GGML_OP_ROPE:
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
deleted file mode 100644 (file)
index 712b921..0000000
+++ /dev/null
@@ -1,123 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_SUFFIX": "inplace",
-    "DECLS": ["INPLACE"]
-  },
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(NOT_INPLACE)
-
-fn update(src_offset: u32, dst_offset: u32, scale: f32) {
-    dst[dst_offset] = scale * src[src_offset];
-}
-
-@group(0) @binding(1)
-var<storage, read_write> dst: array<f32>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(NOT_INPLACE)
-
-#decl(INPLACE)
-
-fn update(src_offset: u32, dst_offset: u32, scale: f32) {
-    src[dst_offset] = scale * src[src_offset];
-}
-
-@group(0) @binding(1)
-var<uniform> params: Params;
-
-#enddecl(INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
-
-struct Params {
-    offset_src: u32, // in elements
-    offset_dst: u32, // in elements
-
-    // Strides (in elements)
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Shape of src/dst
-    ne0: u32,
-    ne1: u32,
-    ne2: u32,
-    ne3: u32,
-
-    eps: f32
-};
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
-
-DECLS
-
-override wg_size: u32;
-var<workgroup> scratch: array<f32, wg_size>;
-
-@compute @workgroup_size(wg_size)
-fn main(@builtin(workgroup_id) wid: vec3<u32>,
-        @builtin(local_invocation_id) lid: vec3<u32>) {
-
-    // one thread per row
-    var i = wid.x;
-    let i3 = i / (params.ne2 * params.ne1);
-    i = i % (params.ne2 * params.ne1);
-    let i2 = i / params.ne1;
-    let i1 = i % params.ne1;
-    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
-    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
-
-    let elems = (params.ne0 + wg_size - 1) / wg_size;
-
-    var sum = 0.0f;
-    var col = lid.x;
-    for (var j: u32 = 0; j < elems; j++) {
-        if (col >= params.ne0) {
-            break;
-        }
-        sum += pow(src[i_src_row + col], 2.0);
-        col += wg_size;
-    }
-
-    scratch[lid.x] = sum;
-    workgroupBarrier();
-    var offset = wg_size / 2;
-    while (offset > 0) {
-        if (lid.x < offset) {
-            scratch[lid.x] += scratch[lid.x + offset];
-        }
-        offset = offset / 2;
-        workgroupBarrier();
-    }
-    sum = scratch[0];
-
-    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
-    col = lid.x;
-    for (var j: u32 = 0; j < elems; j++) {
-        if (col >= params.ne0) {
-            break;
-        }
-        update(i_src_row + col, i_dst_row + col, scale);
-        col += wg_size;
-    }
-}
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl
new file mode 100644 (file)
index 0000000..7777944
--- /dev/null
@@ -0,0 +1,97 @@
+#ifdef INPLACE
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {  // in-place variant: the scaled value is written back into src
+    src[dst_offset] = scale * src[src_offset];
+}
+// No dst buffer in this variant, so params is bound at binding 1.
+@group(0) @binding(1)
+var<uniform> params: Params;
+#else
+fn update(src_offset: u32, dst_offset: u32, scale: f32) {  // out-of-place variant: the scaled src value is written to dst
+    dst[dst_offset] = scale * src[src_offset];
+}
+// Separate output buffer at binding 1.
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+// With dst present, params is bound at binding 2.
+@group(0) @binding(2)
+var<uniform> params: Params;
+#endif
+
+struct Params {  // uniform parameter block; bound at binding 1 (INPLACE) or binding 2 (otherwise)
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Shape of src/dst
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+    ne3: u32,
+
+    eps: f32  // used in main's scale: added under the sqrt for OP_RMS_NORM, lower bound on the norm for OP_L2_NORM
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;  // input rows; also the output when INPLACE is defined (update() writes into src)
+// One partial-sum slot per invocation of the workgroup; reduced in main.
+var<workgroup> scratch: array<f32, WG_SIZE>;
+
+// Row normalisation entry point (RMS_NORM or L2_NORM, selected by the OP_* define).
+// Each workgroup handles one row: its invocations accumulate a strided partial sum of
+// squares, the partial sums are reduced in workgroup memory, and every element of the
+// row is then multiplied by the resulting scale via update().
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    // one workgroup per row: wid.x is the flattened row index, split into (i1, i2, i3)
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    // number of strided passes each invocation makes over the row: ceil(ne0 / WG_SIZE)
+    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+
+    // partial sum of squares over columns lid.x, lid.x + WG_SIZE, ...
+    var sum = 0.0f;
+    var col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        // x * x rather than pow(x, 2.0): a negative base is outside the domain WGSL defines for pow()
+        let x = src[i_src_row + col];
+        sum += x * x;
+        col += WG_SIZE;
+    }
+
+    // Tree reduction of the partial sums in workgroup memory.
+    // NOTE(review): the halving loop only covers every slot when WG_SIZE is a power of two - confirm.
+    scratch[lid.x] = sum;
+    workgroupBarrier();
+    var offset: u32 = WG_SIZE / 2;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            scratch[lid.x] += scratch[lid.x + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+    // scratch[0] now holds the sum of squares of the whole row
+    sum = scratch[0];
+
+#ifdef OP_RMS_NORM
+    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
+#elif OP_L2_NORM
+    let scale = 1.0/max(sqrt(sum), params.eps);
+#endif
+    // second strided pass: write scale * src for every column of the row
+    col = lid.x;
+    for (var j: u32 = 0; j < elems; j++) {
+        if (col >= params.ne0) {
+            break;
+        }
+        update(i_src_row + col, i_dst_row + col, scale);
+        col += WG_SIZE;
+    }
+}