]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix llm_build_k_shift to use correct n_rot (#4889)
authorGeorgi Gerganov <redacted>
Fri, 12 Jan 2024 11:01:56 +0000 (13:01 +0200)
committerGitHub <redacted>
Fri, 12 Jan 2024 11:01:56 +0000 (13:01 +0200)
* llama : fix llm_build_k_shift to use correct n_rot

ggml-ci

* llama : always use hparams.n_rot for ggml_rope_custom

ggml-ci

* convert : fix persimmon conversion to write correct n_rot

common/common.cpp
convert-hf-to-gguf.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

index b2cb0e257a817dfc980a6ea6031b15b9475fa611..3aefed01d30493d2a0efe5866485f06c3d42b763 100644 (file)
@@ -1055,6 +1055,9 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
 }
 
 static ggml_type kv_cache_type_from_str(const std::string & s) {
+    if (s == "f32") {
+        return GGML_TYPE_F32;
+    }
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
index 203eaf64b3fc36a1a25dacb4fded02ea294c7102..813aeeed680f85bd7c8a6be01845cc5b9a71a4bd 100755 (executable)
@@ -817,10 +817,17 @@ class PersimmonModel(Model):
         hidden_size = self.hparams["hidden_size"]
 
         self.gguf_writer.add_name('persimmon-8b-chat')
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hidden_size)
         self.gguf_writer.add_block_count(block_count)
         self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
-        self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
+
+        # NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
+        #       than the head size?
+        #       ref: https://github.com/ggerganov/llama.cpp/pull/4889
+        #self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
+        self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)
+
         self.gguf_writer.add_head_count(head_count)
         self.gguf_writer.add_head_count_kv(head_count_kv)
         self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
index 80c1d5449cc74ebf8dc5ea2f74811d770b36a194..24a0890378496f2a3d9e0c1805d41806d4971aa9 100644 (file)
@@ -57,6 +57,7 @@ class TensorNameMap:
             "transformer.norm_f",                      # mpt
             "ln_f",                                    # refact bloom qwen gpt2
             "language_model.encoder.final_layernorm",  # persimmon
+            "model.final_layernorm",                   # persimmon
             "lm_head.ln",                              # phi2
         ),
 
@@ -98,6 +99,7 @@ class TensorNameMap:
             "transformer.h.{bid}.self_attention.query_key_value",                  # falcon
             "h.{bid}.self_attention.query_key_value",                              # bloom
             "language_model.encoder.layers.{bid}.self_attention.query_key_value",  # persimmon
+            "model.layers.{bid}.self_attn.query_key_value",                        # persimmon
             "h.{bid}.attn.c_attn",                                                 # gpt2
             "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
         ),
@@ -141,6 +143,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.output.dense",                # bert
             "transformer.h.{bid}.attn.out_proj",                         # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
+            "model.layers.{bid}.self_attn.dense",                        # persimmon
             "h.{bid}.attn.c_proj",                                       # gpt2
             "transformer.h.{bid}.mixer.out_proj",                        # phi2
             "model.layers.layers.{bid}.self_attn.o_proj",                # plamo
@@ -184,6 +187,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.intermediate.dense",                 # bert
             "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
+            "model.layers.{bid}.mlp.dense_h_to_4h",                   # persimmon
             "transformer.h.{bid}.mlp.w1",                             # qwen
             "h.{bid}.mlp.c_fc",                                       # gpt2
             "transformer.h.{bid}.mlp.fc1",                            # phi2
@@ -225,6 +229,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.output.dense",                       # bert
             "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
+            "model.layers.{bid}.mlp.dense_4h_to_h",                   # persimmon
             "h.{bid}.mlp.c_proj",                                     # gpt2
             "transformer.h.{bid}.mlp.fc2",                            # phi2
             "model.layers.layers.{bid}.mlp.down_proj",                # plamo
@@ -237,10 +242,12 @@ class TensorNameMap:
 
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
+            "model.layers.{bid}.self_attn.q_layernorm",                       # persimmon
         ),
 
         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
+            "model.layers.{bid}.self_attn.k_layernorm",                       # persimmon
         ),
 
         MODEL_TENSOR.ROPE_FREQS: (
index d39ff94c7fae696ccba29c25cc5e16b09588fd77..0bab95563a226a9cb494b11a2f2271ec8fd0784d 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -4104,7 +4104,6 @@ static void llm_build_k_shift(
        struct ggml_cgraph * graph,
             llm_rope_type   type,
                   int64_t   n_ctx,
-                  int       n_rot,
                   float     freq_base,
                   float     freq_scale,
        const llm_build_cb & cb) {
@@ -4112,14 +4111,13 @@ static void llm_build_k_shift(
     const int64_t n_head_kv     = hparams.n_head_kv;
     const int64_t n_embd_head_k = hparams.n_embd_head_k;
     const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const int32_t n_rot         = hparams.n_rot;
     const int32_t n_orig_ctx    = cparams.n_yarn_orig_ctx;
     const float   ext_factor    = cparams.yarn_ext_factor;
     const float   attn_factor   = cparams.yarn_attn_factor;
     const float   beta_fast     = cparams.yarn_beta_fast;
     const float   beta_slow     = cparams.yarn_beta_slow;
 
-    GGML_ASSERT(n_embd_head_k % n_rot == 0);
-
     struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
     cb(K_shift, "K_shift", -1);
 
@@ -4523,7 +4521,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -4561,14 +4559,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
-                    n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -4691,6 +4689,7 @@ struct llm_build_context {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -4708,7 +4707,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -4734,12 +4733,12 @@ struct llm_build_context {
                     case MODEL_7B:
                         Qcur = ggml_rope_custom(
                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
-                            n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         Kcur = ggml_rope_custom(
                             ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                            n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
@@ -4812,6 +4811,7 @@ struct llm_build_context {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -4829,7 +4829,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -4870,13 +4870,13 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -5033,9 +5033,8 @@ struct llm_build_context {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        const int64_t n_rot = n_embd_head_k / 2;
+        GGML_ASSERT(n_embd_head   == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head/2 == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -5052,7 +5051,7 @@ struct llm_build_context {
         cb(KQ_mask, "KQ_mask", -1);
 
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5112,7 +5111,7 @@ struct llm_build_context {
 
                 // RoPE the first n_rot of q/k, pass the other half, and concat.
                 struct ggml_tensor * qrot = ggml_view_3d(
-                        ctx0, tmpq, n_rot, n_head, n_tokens,
+                        ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
                         ggml_element_size(tmpq) * n_embd_head,
                         ggml_element_size(tmpq) * n_embd_head * n_head,
                         0
@@ -5120,7 +5119,7 @@ struct llm_build_context {
                 cb(qrot, "qrot", il);
 
                 struct ggml_tensor * krot = ggml_view_3d(
-                        ctx0, tmpk, n_rot, n_head, n_tokens,
+                        ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
                         ggml_element_size(tmpk) * n_embd_head,
                         ggml_element_size(tmpk) * n_embd_head * n_head,
                         0
@@ -5129,29 +5128,29 @@ struct llm_build_context {
 
                 // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
                 struct ggml_tensor * qpass = ggml_view_3d(
-                        ctx0, tmpq, n_rot, n_head, n_tokens,
+                        ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
                         ggml_element_size(tmpq) * n_embd_head,
                         ggml_element_size(tmpq) * n_embd_head * n_head,
-                        ggml_element_size(tmpq) * n_rot
+                        ggml_element_size(tmpq) * hparams.n_rot
                         );
                 cb(qpass, "qpass", il);
 
                 struct ggml_tensor * kpass = ggml_view_3d(
-                        ctx0, tmpk, n_rot, n_head, n_tokens,
+                        ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
                         ggml_element_size(tmpk) * n_embd_head,
                         ggml_element_size(tmpk) * n_embd_head * n_head,
-                        ggml_element_size(tmpk) * n_rot
+                        ggml_element_size(tmpk) * hparams.n_rot
                         );
                 cb(kpass, "kpass", il);
 
                 struct ggml_tensor * qrotated = ggml_rope_custom(
-                    ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+                    ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(qrotated, "qrotated", il);
 
                 struct ggml_tensor * krotated = ggml_rope_custom(
-                    ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+                    ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(krotated, "krotated", il);
@@ -5531,6 +5530,7 @@ struct llm_build_context {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -5548,7 +5548,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5661,7 +5661,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5693,13 +5693,13 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -5778,7 +5778,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5874,6 +5874,7 @@ struct llm_build_context {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -5891,7 +5892,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -5917,13 +5918,13 @@ struct llm_build_context {
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, hparams.n_rot, n_head,    n_tokens), inp_pos,
                         n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, hparams.n_rot, n_head_kv, n_tokens), inp_pos,
                         n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);