]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
make rms_norm_eps a parameter (#2374)
authorslaren <redacted>
Mon, 24 Jul 2023 15:57:12 +0000 (17:57 +0200)
committerGitHub <redacted>
Mon, 24 Jul 2023 15:57:12 +0000 (17:57 +0200)
* make rms_norm_eps a parameter

* add rms_norm_eps to command line

* fix baby llama, test-grad0

* use scientific notation for eps param in the help

ggml-ci

examples/baby-llama/baby-llama.cpp
examples/common.cpp
examples/common.h
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-cuda.cu
ggml-metal.m
ggml.c
ggml.h
llama.cpp
llama.h
tests/test-grad0.c

index 4965881ecec22a02141ac1fe6bf456c1efcb2d20..f9dc0aaa61ed17a932c20534c33fcb2fd1146b84 100644 (file)
@@ -8,6 +8,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 float frand() {
     return (float)rand()/(float)RAND_MAX;
 }
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -685,7 +687,7 @@ struct ggml_tensor * forward(
             // norm
             {
                 // cur shape [n_embd,N,1,1]
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
                 // cur = ffn_norm*cur
                 // cur shape [n_embd,N,1,1]
@@ -729,7 +731,7 @@ struct ggml_tensor * forward(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
@@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
             // norm
             {
                 // cur shape [n_embd,N*n_batch,1,1]
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 assert_shape_2d(cur, n_embd, N*n_batch);
 
                 // cur = ffn_norm*cur
@@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
             // norm
             {
                 // cur shape [n_embd,N,1,1]
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
                 // cur = ffn_norm*cur
                 // cur shape [n_embd,N,1,1]
@@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
index 779605f9d1cb797041795d5eb589aa1258e213fb..0e88a128ad1956e609d5d4a38bd70eadb00a405e 100644 (file)
@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_gqa = std::stoi(argv[i]);
+        } else if (arg == "-eps" || arg == "--rms-norm-eps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rms_norm_eps = std::stof(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, "  -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, "  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, "  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.n_ctx           = params.n_ctx;
     lparams.n_batch         = params.n_batch;
     lparams.n_gqa           = params.n_gqa;
+    lparams.rms_norm_eps    = params.rms_norm_eps;
     lparams.n_gpu_layers    = params.n_gpu_layers;
     lparams.main_gpu        = params.main_gpu;
     lparams.tensor_split    = params.tensor_split;
index 7086606bf1a09cc604c9a8d13c0f3db63f862f28..894a0850a68c6de1bf8de5d4a55b3d70a599cdc7 100644 (file)
 int32_t get_num_physical_cores();
 
 struct gpt_params {
-    uint32_t seed                           = -1;  // RNG seed
+    uint32_t seed                           = -1;   // RNG seed
     int32_t n_threads                       = get_num_physical_cores();
-    int32_t n_predict                       = -1;  // new tokens to predict
-    int32_t n_ctx                           = 512; // context size
-    int32_t n_batch                         = 512; // batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_gqa                           = 1;   // grouped-query attention factor (TODO: move to hparams)
-    int32_t n_keep                          = 0;   // number of tokens to keep from initial prompt
-    int32_t n_chunks                        = -1;  // max number of chunks to process (-1 = unlimited)
-    int32_t n_gpu_layers                    = 0;   // number of layers to store in VRAM
-    int32_t main_gpu                        = 0;   // the GPU that is used for scratch and small tensors
-    float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
-    int32_t n_probs                         = 0;   // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_predict                       = -1;   // new tokens to predict
+    int32_t n_ctx                           = 512;  // context size
+    int32_t n_batch                         = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_gqa                           = 1;    // grouped-query attention factor (TODO: move to hparams)
+    int32_t n_keep                          = 0;    // number of tokens to keep from initial prompt
+    int32_t n_chunks                        = -1;   // max number of chunks to process (-1 = unlimited)
+    int32_t n_gpu_layers                    = 0;    // number of layers to store in VRAM
+    int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
+    float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
+    int32_t n_probs                         = 0;    // if greater than 0, output the probabilities of top n_probs tokens.
+    float   rms_norm_eps                    = 1e-6; // rms norm epsilon
     float   rope_freq_base                  = 10000.0f; // RoPE base frequency
     float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor
 
index 449b4e9ecdd549315dd9c952e5adf692f67bd89e..4bbf6b7826c62205324febcca1bda8e1165aefb1 100644 (file)
@@ -16,6 +16,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 struct random_normal_distribution {
     std::mt19937 gen;
     std::normal_distribution<float> rd;
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
             // norm
             {
                 // cur shape [n_embd,N,1,1]
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
                 // cur = ffn_norm*cur
                 // cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
             // norm
             {
                 // cur shape [n_embd,N*n_batch,1,1]
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 assert_shape_2d(cur, n_embd, N*n_batch);
 
                 // cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
             // norm
             {
                 // cur shape [n_embd,N*n_batch,1,1]
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 assert_shape_2d(cur, n_embd, N*n_batch);
 
                 // cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
         {
             // norm
             {
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 assert_shape_2d(cur, n_embd, N*n_batch);
 
                 // cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     // norm
     {
 
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         struct my_llama_layer & layer = model->layers[il];
         // tensors with values necessary for backward pass are in persistent buf(-1)
         // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
-        use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm     (ctx0, cur));                                    assert_shape_2d(t02, n_embd, N*n_batch);
+        use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm     (ctx0, cur, rms_norm_eps));                      assert_shape_2d(t02, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat       (ctx0, layer.attention_norm, t02));              assert_shape_2d(t03, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul          (ctx0, t02, t03));                               assert_shape_2d(t04, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat      (ctx0, layer.wq, t04));                          assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d   (ctx0, t18, n_embd, N*n_batch));                 assert_shape_2d(t19, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat      (ctx0, layer.wo, t19));                          assert_shape_2d(t20, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add          (ctx0, t20, cur));                               assert_shape_2d(t21, n_embd, N*n_batch);
-        use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm     (ctx0, t21));                                    assert_shape_2d(t22, n_embd, N*n_batch);
+        use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm     (ctx0, t21, rms_norm_eps));                      assert_shape_2d(t22, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat       (ctx0, layer.ffn_norm, t22));                    assert_shape_2d(t23, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul          (ctx0, t23, t22));                               assert_shape_2d(t24, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat      (ctx0, layer.w3, t24));                          assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     }
     clr_buf(0);
     use_buf(0);
-    struct ggml_tensor * t31   = expand(gf, ggml_rms_norm  (ctx0, cur));                       assert_shape_2d(t31, n_embd, N*n_batch);
+    struct ggml_tensor * t31   = expand(gf, ggml_rms_norm  (ctx0, cur, rms_norm_eps));         assert_shape_2d(t31, n_embd, N*n_batch);
     struct ggml_tensor * t32   = expand(gf, ggml_repeat    (ctx0, model->norm, t31));          assert_shape_2d(t32, n_embd, N*n_batch);
     struct ggml_tensor * t33   = expand(gf, ggml_mul       (ctx0, t32, t31));                  assert_shape_2d(t33, n_embd, N*n_batch);
     use_buf(-1);
index b8c98354da192ae5a476995e90b122a0c904d582..87a1660613613c35215e802a790a161542dd7f70 100644 (file)
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-6f;
-
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
     // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
 
     (void) src1;
     (void) dst;
index 1fd6e857ffe6138dc2a5be702bebcd57eb5c6fc5..c1db3d16573be8cff5229eceb4da314b36197a3b 100644 (file)
@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            const float eps = 1e-6f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 512;
 
diff --git a/ggml.c b/ggml.c
index 960b8057709a987e8aa83395fbe51a5c597fe21f..11226c834de7bc8d1a453fe57516fd1dd699c834 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
 static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op   = GGML_OP_RMS_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
 struct ggml_tensor * ggml_rms_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_rms_norm_impl(ctx, a, false);
+        struct ggml_tensor  * a,
+        float  eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_rms_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_rms_norm_impl(ctx, a, true);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
 struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-6f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
diff --git a/ggml.h b/ggml.h
index de44fba9e0961886ba2d35046e60dc7934f05961..1870b62e8aa1fff1ee13b97aa47e5be8aec8b683 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -866,14 +866,17 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     // a - x
     // b - dy
+    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 0288f7e1f65bdbb7254f6f149705642501d1f6ea..b42b41008654422e213eac660398024b8caee33a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -186,6 +186,7 @@ struct llama_hparams {
     // LLaMAv2
     // TODO: load from model data hparams
     float f_ffn_mult = 1.0f;
+    float f_rms_norm_eps = 1e-6f;
 
     float rope_freq_base  = 10000.0f;
     float rope_freq_scale = 1.0f;
@@ -869,6 +870,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
         /*.n_gqa                       =*/ 1,
+        /*.rms_norm_eps                =*/ 1e-6f,
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
@@ -1000,6 +1002,7 @@ static void llama_model_load_internal(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1024,6 +1027,9 @@ static void llama_model_load_internal(
 
     auto & hparams = model.hparams;
 
+    // TODO: read from file
+    hparams.f_rms_norm_eps = rms_norm_eps;
+
     {
         switch (hparams.n_layer) {
             case 26: model.type = e_model::MODEL_3B; break;
@@ -1072,6 +1078,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_layer    = %u\n",   __func__, hparams.n_layer);
         fprintf(stderr, "%s: n_rot      = %u\n",   __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
         fprintf(stderr, "%s: n_gqa      = %u\n",   __func__, hparams.n_gqa());
+        fprintf(stderr, "%s: rnorm_eps  = %.1e\n", __func__, hparams.f_rms_norm_eps);
         fprintf(stderr, "%s: n_ff       = %u\n",   __func__, n_ff);
         fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
         fprintf(stderr, "%s: freq_scale = %g\n",   __func__, hparams.rope_freq_scale);
@@ -1330,6 +1337,7 @@ static bool llama_model_load(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1343,7 +1351,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1396,10 +1404,12 @@ static bool llama_eval_internal(
     const int64_t n_vocab     = hparams.n_vocab;
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
+
     LLAMA_ASSERT(n_embd_head == hparams.n_rot);
 
     const float freq_base  = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
+    const float rms_norm_eps = hparams.f_rms_norm_eps;
 
     const int n_gpu_layers = model.n_gpu_layers;
 
@@ -1479,7 +1489,7 @@ static bool llama_eval_internal(
 
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
 
@@ -1627,7 +1637,7 @@ static bool llama_eval_internal(
         {
             // norm
             {
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 offload_func(cur);
                 ggml_set_name(cur, "rms_norm_1");
 
@@ -1680,7 +1690,7 @@ static bool llama_eval_internal(
 
     // norm
     {
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         offload_func_nr(cur);
         ggml_set_name(cur, "rms_norm_2");
 
@@ -3084,7 +3094,7 @@ struct llama_model * llama_load_model_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
+    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
                 params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
                 memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
                 params.progress_callback_user_data)) {
diff --git a/llama.h b/llama.h
index 81a30e16b2a300e890393dfffc8e97cb9c140092..843b0bf5f7f5a5bdc34d6f6e0a8217b0f483ff36 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -87,6 +87,7 @@ extern "C" {
         int32_t  n_ctx;        // text context
         int32_t  n_batch;      // prompt processing batch size
         int32_t  n_gqa;        // grouped-query attention (TEMP - will be moved to model hparams)
+        float    rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
         int32_t  n_gpu_layers; // number of layers to store in VRAM
         int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
 
index ef20bce516662e645395475e3dc6fdf01756b4ab..6d312216d58af7c89287696d00127966105bf7d7 100644 (file)
@@ -850,7 +850,7 @@ int main(int argc, const char ** argv) {
                     ggml_set_param(ctx0, x[i]);
                 }
 
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
 
                 check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
             }