]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add epsilon as a parameter for group_norm (#8818)
authorMolly Sophia <redacted>
Tue, 6 Aug 2024 07:26:46 +0000 (15:26 +0800)
committerGitHub <redacted>
Tue, 6 Aug 2024 07:26:46 +0000 (10:26 +0300)
Signed-off-by: Molly Sophia <redacted>
ggml/include/ggml.h
ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cuda/norm.cu
ggml/src/ggml-metal.m
ggml/src/ggml-sycl/norm.cpp
ggml/src/ggml.c
tests/test-backend-ops.cpp

index a9e88e592d51c8b89a6be84855c9de751566ff9f..15602a96df7ad3ef4df675d2a786a51d888c6cdb 100644 (file)
@@ -1140,16 +1140,17 @@ extern "C" {
 
     // group normalize along ne0*ne1*n_groups
     // used in stable-diffusion
-    // TODO: eps is hardcoded to 1e-6 for now
     GGML_API struct ggml_tensor * ggml_group_norm(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_group_norm_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);
 
     // a - x
     // b - dy
index 171439132ff2a6110c2f40fe807ae5203ffef3fb..8c4132f5bb7ad74b0f1544777b978541b23fc879 100644 (file)
@@ -464,9 +464,11 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
-    const float eps = 1e-6f;  // TODO: make this a parameter
     int n_groups = dst->op_params[0];
 
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
index 30866d51274fb2c6a7f0db4301d216dff2320464..133e219f0aeda890bd2de2026c424e882e6643e8 100644 (file)
@@ -142,8 +142,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     }
 }
 
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
-    static const float eps = 1e-6f;
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -196,8 +195,12 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
+    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
 }
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 48b8131312a3ef82ed5662975e200e6ef776f220..b512eb0be132e7dfd5dac5d7e4dcee0f1fce5fe9 100644 (file)
@@ -2229,10 +2229,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         GGML_ASSERT(ne00 % 4 == 0);
                         GGML_ASSERT(ggml_is_contiguous(src0));
 
-                        //float eps;
-                        //memcpy(&eps, dst->op_params, sizeof(float));
-
-                        const float eps = 1e-6f; // TODO: temporarily hardcoded
+                        float eps;
+                        memcpy(&eps, dst->op_params + 1, sizeof(float));
 
                         const int32_t n_groups = ((int32_t *) dst->op_params)[0];
 
index cccf87d069a31a57edf005a0c38a02a73a907ffa..b3159b9d1b94d63db4ea421171da0fce372d8f50 100644 (file)
@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
 }
 
 static void group_norm_f32_sycl(const float* x, float* dst,
-    const int num_groups, const int group_size,
+    const int num_groups, const float eps, const int group_size,
     const int ne_elements, queue_ptr stream, int device) {
-    static const float eps = 1e-6f;
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
 
     (void)src1;
     (void)dst;
index 910981e4a37ba680f16937e8df2f0c6c295ae81f..daceec4145b7eca01f494d0d41b094adda1bd70a 100644 (file)
@@ -5374,6 +5374,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
     int n_groups,
+    float eps,
     bool inplace) {
 
     bool is_node = false;
@@ -5384,7 +5385,8 @@ static struct ggml_tensor * ggml_group_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op_params[0] = n_groups;
+    ggml_set_op_params_i32(result, 0, n_groups);
+    ggml_set_op_params_f32(result, 1, eps);
 
     result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5396,15 +5398,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
 struct ggml_tensor * ggml_group_norm(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
-    int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, false);
+    int n_groups,
+    float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
 }
 
 struct ggml_tensor * ggml_group_norm_inplace(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
-    int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, true);
+    int n_groups,
+    float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
 // ggml_mul_mat
@@ -12095,10 +12099,11 @@ static void ggml_compute_forward_group_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    const float eps = 1e-6f; // TODO: make this a parameter
-
     // TODO: optimize
 
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int n_channels = src0->ne[2];
     int n_groups = dst->op_params[0];
     int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
index 54cef05c3de3d376ce548dc4264498f325e5f561..2f4117a627de0f5582626da0893558237934aeb0 100644 (file)
@@ -1511,6 +1511,7 @@ struct test_group_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int32_t num_groups;
+    const float eps;
 
     std::string vars() override {
         return VARS_TO_STR3(type, ne, num_groups);
@@ -1518,12 +1519,13 @@ struct test_group_norm : public test_case {
 
     test_group_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 64, 320, 1},
-            int32_t num_groups = 32)
-        : type(type), ne(ne), num_groups(num_groups) {}
+            int32_t num_groups = 32,
+            float eps = 1e-6f)
+        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
+        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
         return out;
     }
 };