llama : fix n_rot default (#8348)
author    Georgi Gerganov <redacted>
          Sun, 7 Jul 2024 11:59:02 +0000 (14:59 +0300)
committer GitHub <redacted>
          Sun, 7 Jul 2024 11:59:02 +0000 (14:59 +0300)
ggml-ci

src/llama.cpp

index b39906fd53b6f11adfa0e928171f931929d78a39..1a65faef9bedc5a348fcc9e385ad0281acf39b3e 100644 (file)
@@ -4625,16 +4625,6 @@ static void llm_load_hparams(
 
     // non-transformer models do not have attention heads
     if (hparams.n_head() > 0) {
-        // sanity check for n_rot (optional)
-        hparams.n_rot = hparams.n_embd / hparams.n_head();
-
-        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
-            if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
-                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
-            }
-        }
         // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
         // gpt-j n_rot = rotary_dim
 
@@ -4643,6 +4633,17 @@ static void llm_load_hparams(
 
         hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
         ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+
+        // sanity check for n_rot (optional)
+        hparams.n_rot = hparams.n_embd_head_k;
+
+        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
+
+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
+            if (hparams.n_rot != hparams.n_embd_head_k) {
+                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
+            }
+        }
     } else {
         hparams.n_rot = 0;
         hparams.n_embd_head_k = 0;
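
The two hunks above move the n_rot default below the point where n_embd_head_k is resolved: previously n_rot defaulted to n_embd / n_head() even when the GGUF overrode the head size via the attention.key_length key (the "{arch}." prefix is omitted here), whereas it now defaults to the resolved n_embd_head_k. A minimal sketch of the resolution order, using hypothetical names (hparams_t and the stubbed get_key_optional stand in for the real llm_load_hparams plumbing): for a hypothetical model with n_embd = 2048, n_head = 32, and attention.key_length = 128, the old default gave n_rot = 2048/32 = 64, while the new one gives 128.

    #include <cstdint>
    #include <stdexcept>
    #include <string>

    // hypothetical mirror of the relevant hparams fields
    struct hparams_t {
        uint32_t n_embd        = 2048;
        uint32_t n_head        = 32;
        uint32_t n_embd_head_k = 0;
        uint32_t n_rot         = 0;
    };

    // stand-in for ml.get_key(key, value, /*required=*/false); a real
    // implementation would read the GGUF metadata - here it is a stub
    // that pretends the key is absent and leaves the default in place
    static bool get_key_optional(const std::string & /*key*/, uint32_t & /*out*/) {
        return false;
    }

    static void resolve_n_rot(hparams_t & hp, bool arch_is_llama_or_falcon) {
        // head size defaults to n_embd / n_head, but the GGUF may override it
        hp.n_embd_head_k = hp.n_embd / hp.n_head;
        get_key_optional("attention.key_length", hp.n_embd_head_k);

        // the fix: default n_rot to the *resolved* head size, not to the
        // pre-override n_embd / n_head value
        hp.n_rot = hp.n_embd_head_k;
        get_key_optional("rope.dimension_count", hp.n_rot);

        // same sanity check as before, now against n_embd_head_k
        if (arch_is_llama_or_falcon && hp.n_rot != hp.n_embd_head_k) {
            throw std::runtime_error("invalid n_rot: " + std::to_string(hp.n_rot));
        }
    }
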
@@ -11490,7 +11491,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11499,7 +11500,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -11603,7 +11604,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -11612,7 +11613,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
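
At the four ggml_rope_ext call sites above, the number of rotated dimensions changes from n_embd_head_k to n_rot, so models with partial rotary embeddings (gpt-neox style n_rot = rotary_pct * head size, per the comment in the first hunk) rotate only the intended prefix of each head. A minimal sketch of that behaviour for a single head vector, assuming the original paired-dimension layout and ignoring rope_type, freq_scale, and the YaRN parameters (ext_factor, attn_factor, beta_fast, beta_slow) that the real kernel takes:

    #include <cmath>
    #include <vector>

    // rotate only the first n_rot dimensions of one head at position `pos`;
    // dimensions [n_rot, head.size()) pass through unchanged - this is what
    // the n_rot argument to ggml_rope_ext controls in the hunks above
    // (assumes n_rot is even and n_rot <= head.size())
    static void rope_partial(std::vector<float> & head, int n_rot, int pos,
                             float freq_base = 10000.0f) {
        for (int i = 0; i < n_rot; i += 2) {
            // theta for pair i: pos * freq_base^(-i/n_rot)
            const float theta = pos * std::pow(freq_base, -(float) i / n_rot);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const float x0 = head[i + 0];
            const float x1 = head[i + 1];
            head[i + 0] = x0 * c - x1 * s;
            head[i + 1] = x0 * s + x1 * c;
        }
        // with the old n_embd_head_k argument, a model with n_rot < head size
        // would have rotated dimensions it should have left alone
    }

Passing n_embd_head_k was harmless only when n_rot equals the full head size, which is exactly what the relocated sanity check enforces for LLM_ARCH_LLAMA and LLM_ARCH_FALCON.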