]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: add rms_norm_eps parameter (#2380)
authorslaren <redacted>
Tue, 25 Jul 2023 09:36:17 +0000 (11:36 +0200)
committerGitHub <redacted>
Tue, 25 Jul 2023 09:36:17 +0000 (12:36 +0300)
examples/server/server.cpp

index 4ad0ba9ecb910f1c057f518ef27b683943899916..83c03065a5d583daf150026f796532c165a8e224 100644 (file)
@@ -609,6 +609,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, "  -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -734,6 +735,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_gqa = std::stoi(argv[i]);
         }
+        else if (arg == "-eps" || arg == "--rms-norm-eps") {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.rms_norm_eps = std::stof(argv[i]);
+        }
         else if (arg == "--rope-freq-base")
         {
             if (++i >= argc)