]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add RMS norm and use it (#187)
authorhoangmit <redacted>
Wed, 15 Mar 2023 22:41:38 +0000 (18:41 -0400)
committerGitHub <redacted>
Wed, 15 Mar 2023 22:41:38 +0000 (00:41 +0200)
* add ggml_rms_norm

* update op num

ggml.c
ggml.h
main.cpp

diff --git a/ggml.c b/ggml.c
index a0c0dd03b01d400a40f55f7a3b20218a75ebfe4c..eee54f7ffc89029149567a1f522b4077621bd2f2 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2069,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "GELU",
     "SILU",
     "NORM",
+    "RMS_NORM",
 
     "MUL_MAT",
 
@@ -2089,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2112,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gelu(x)",
     "silu(x)",
     "norm(x)",
+    "rms_norm(x)",
 
     "X*Y",
 
@@ -2132,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
-static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34");
+static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
 
 //
 // ggml object
@@ -3618,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }
 
+struct ggml_tensor * ggml_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_RMS_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_rms_norm_impl(ctx, a, true);
+}
+
 // ggml_mul_mat
 
 struct ggml_tensor * ggml_mul_mat(
@@ -5406,6 +5441,87 @@ static void ggml_compute_forward_norm(
     }
 }
 
+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+
+    // TODO: optimize
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = 0; i02 < ne02; i02++) {
+            for (int i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float mean = 0.0;
+                for (int i00 = 0; i00 < ne00; i00++) {
+                    mean += x[i00] * x[i00];
+                }
+
+                mean /= ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+                
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0/sqrt(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, src0, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -8522,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_norm(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
@@ -8764,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 if (src0->grad) {
diff --git a/ggml.h b/ggml.h
index 7ce655c1b04c963f4450f2c99208695bfcac4ceb..bac4fe65c7ad5122551c01e011cabf053b349819 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -230,6 +230,7 @@ enum ggml_op {
     GGML_OP_GELU,
     GGML_OP_SILU,
     GGML_OP_NORM, // normalize
+    GGML_OP_RMS_NORM,
 
     GGML_OP_MUL_MAT,
 
@@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
 
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // A: m rows, n columns
 // B: p rows, n columns (i.e. we transpose it internally)
 // result is m columns, p rows
index a812d0fa01842912e34c14aaf9166c42b6261807..ca0fca8b364559f4d29e4cf51c82c07d9e4ec375 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -588,7 +588,7 @@ bool llama_eval(
 
         // norm
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -678,7 +678,7 @@ bool llama_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF);
 
                 // cur = ffn_norm*cur
                 cur = ggml_mul(ctx0,
@@ -713,7 +713,7 @@ bool llama_eval(
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL);
 
         // inpL = norm*inpL
         inpL = ggml_mul(ctx0,