py : fix internlm2-hf convert to gguf (#5305)
author Guoteng <redacted>
Mon, 5 Feb 2024 09:04:06 +0000 (17:04 +0800)
committer GitHub <redacted>
Mon, 5 Feb 2024 09:04:06 +0000 (11:04 +0200)
* py : fix internlm2-hf convert to gguf

* ggml-ci

convert-hf-to-gguf.py

index a6ffd128b6bbbf7d2340752d263e6031be1750fc..5e343742d24847bb7940e68f401859259f135b4e 100755 (executable)
@@ -1416,8 +1416,32 @@ class InternLM2Model(Model):
         self.gguf_writer.add_add_space_prefix(add_prefix)
 
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        old_eos = special_vocab.special_token_ids["eos"]
+        if "chat" in os.path.basename(self.dir_model.absolute()):
+            # For the chat model, we replace the eos with '<|im_end|>'.
+            special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
+            print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
+in chat mode so that the conversation can end normally.")
+
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _try_get_sft_eos(self, tokenizer):
+        unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]')
+        im_end_list = tokenizer.encode('<|im_end|>')
+        assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1)
+        if len(unused_145_list) == 1:
+            eos_token = unused_145_list[0]
+        if len(im_end_list) == 1:
+            eos_token = im_end_list[0]
+        return eos_token
+
+    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
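A minimal standalone sketch of the EOS detection above, runnable against a chat checkpoint (the model id, the use of AutoTokenizer, and the add_special_tokens flag are assumptions for illustration; the converter passes in its own tokenizer object):

    from transformers import AutoTokenizer

    # assumption: a local or hub copy of an InternLM2 chat checkpoint
    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-chat-7b", trust_remote_code=True)

    # exactly one of the two markers should encode to a single token id;
    # which one it is depends on the tokenizer revision
    unused_145 = tokenizer.encode("[UNUSED_TOKEN_145]", add_special_tokens=False)
    im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False)
    assert (len(unused_145) == 1) ^ (len(im_end) == 1)
    eos_id = unused_145[0] if len(unused_145) == 1 else im_end[0]
    print(f"chat eos id: {eos_id}")
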
@@ -1486,8 +1510,9 @@ class InternLM2Model(Model):
                 qkv = data_torch
                 qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
                 q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
-                q = rearrange(q, " o g n i ->  o (g n i)").T
-                k = rearrange(k, " o g n i ->  o (g n i)").T
+                # The model weights of q and k require an additional reshape.
+                q = self._hf_permute_qk(rearrange(q, " o g n i ->  o (g n i)").T, num_heads, num_heads)
+                k = self._hf_permute_qk(rearrange(k, " o g n i ->  o (g n i)").T, num_heads, num_kv_heads)
                 v = rearrange(v, " o g n i ->  o (g n i)").T
                 self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
                 self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)