llama : fix MiniCPM (#5392)
author      runfuture <redacted>
            Thu, 8 Feb 2024 10:36:19 +0000 (18:36 +0800)
committer   GitHub <redacted>
            Thu, 8 Feb 2024 10:36:19 +0000 (12:36 +0200)
* fix bug with missing norm_rms_eps

* align the metadata write order with convert.py

* fix: undo the HF model's tensor permute

* update for flake8 lint
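
The "undo the HF model's tensor permute" item refers to the row reordering that Hugging Face checkpoints apply to q_proj/k_proj for their rotary-embedding convention; the converter has to invert it so llama.cpp's RoPE sees the layout it expects. Below is a minimal numpy sketch of the reshape/swapaxes pattern used by _reverse_hf_permute in the diff (illustrative only, not the committed code; it omits the grouped-KV n_kv_head adjustment), together with the complementary reshape showing that the two are inverses:

    # illustrative sketch only, not part of the commit
    import numpy as np

    def reverse_hf_permute(w: np.ndarray, n_head: int) -> np.ndarray:
        # same reshape/swapaxes pattern as MiniCPMModel._reverse_hf_permute
        return (
            w.reshape(n_head, 2, w.shape[0] // n_head // 2, *w.shape[1:])
            .swapaxes(1, 2)
            .reshape(w.shape)
        )

    def hf_permute(w: np.ndarray, n_head: int) -> np.ndarray:
        # complementary reshape: applying it after reverse_hf_permute restores w
        return (
            w.reshape(n_head, w.shape[0] // n_head // 2, 2, *w.shape[1:])
            .swapaxes(1, 2)
            .reshape(w.shape)
        )

    n_head, head_dim, n_embd = 4, 8, 32
    w = np.arange(n_head * head_dim * n_embd, dtype=np.float32).reshape(n_head * head_dim, n_embd)
    assert np.array_equal(hf_permute(reverse_hf_permute(w, n_head), n_head), w)

A round-trip check like the assert above is a quick way to sanity-check this kind of layout change before running a full conversion.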

convert-hf-to-gguf.py
llama.cpp

index 829d6836888ec67950e5587ec341cd310e008555..0d4ea03b44051c731d7cfc12df79e0c4a09234db 100755 (executable)
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1078,17 +1078,76 @@ class MiniCPMModel(Model):
         self.gguf_writer.add_name("MiniCPM")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
-        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
         self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
 
     def set_vocab(self):
         self._set_vocab_hf()
 
+    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
+        if n_kv_head is not None and n_head != n_kv_head:
+            n_head //= n_kv_head
+
+        return (
+            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+            .swapaxes(1, 2)
+            .reshape(weights.shape)
+        )
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_head = self.hparams.get("num_attention_heads")
+        n_kv_head = self.hparams.get("num_key_value_heads")
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            # HF models permute some of the tensors, so we need to undo that
+            if name.endswith(("q_proj.weight")):
+                data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
+            if name.endswith(("k_proj.weight")):
+                data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
 
 class QwenModel(Model):
     @staticmethod
index f8f5796a4381474328bf9d5ceb453bacba81af57..552e0d02e56c55c89d79655def101230b266a62e 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2947,6 +2947,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MINICPM:
             {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
                 switch (hparams.n_layer) {
                     case 40: model.type = e_model::MODEL_2B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
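
For reference, the key read here is the one the converter writes via add_layer_norm_rms_eps (minicpm.attention.layer_norm_rms_epsilon in the GGUF metadata); before this change the value was simply never read for the MiniCPM architecture. Below is a minimal numpy sketch of RMS norm, just to show where that epsilon ends up (illustrative only, not the ggml implementation; the 1e-5 value is an assumed rms_norm_eps for MiniCPM, not taken from the commit):

    # illustrative sketch only: where f_norm_rms_eps is ultimately used
    import numpy as np

    def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float) -> np.ndarray:
        # normalize by the root mean square; eps keeps the denominator away from zero
        return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

    x = np.random.randn(2, 8).astype(np.float32)
    w = np.ones(8, dtype=np.float32)
    y = rms_norm(x, w, eps=1e-5)  # assumed value, read from GGUF metadata in practice

Reading the epsilon from the file ensures inference uses the same value the model was trained with.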