]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : support glm3 and glm4 (#8031)
authortoyer <redacted>
Sun, 7 Jul 2024 12:52:10 +0000 (20:52 +0800)
committerGitHub <redacted>
Sun, 7 Jul 2024 12:52:10 +0000 (15:52 +0300)
* add chatglm3-6b model support huggingface model:
 https://hf-mirror.com/THUDM/chatglm3-6b

Signed-off-by: XingXing Qiao <redacted>
* remove .rotary_pos_emb.inv_freq and unuse code for chatglm3 model

Signed-off-by: XingXing Qiao <redacted>
* fix lint error

Signed-off-by: XingXing Qiao <redacted>
* optimize convert-hf-to-gguf.py for chatglm model

Signed-off-by: XingXing Qiao <redacted>
* support glm-4-9b-chat

Signed-off-by: XingXing Qiao <redacted>
* fix eos tokens to glm4

* remove unused log

* add preprocess to chatglm3 and chatglm4

* add eos_id_list to llama.cpp

* fix code style

* fix code style

* fix conflicts

* fix conflicts

* Revert "add eos_id_list to llama.cpp"

This reverts commit 3a4d5790bfdc205c5b658204239f168fc21cc1a8.

* set <|endoftext|> as eos and <|user|> as eot

* fix chat template bug

* add comment to glm prefix and suffix

* fix conflicts and add rope_ratio & ChatGLMForConditionalGeneration

* fix chat template bug

* fix codestyle

* fix conflicts

* modified the general name of glm model

* fix conflicts

* remove prefix and suffix

* use normal glm4 chattempalte & use LLM_FFN_SWIGLU in phi3

* fix: resolve Flake8 errors in `convert-hf-to-gguf.py`

- Fix E302 by adding two blank lines before top-level function definitions
- Replace print statements to fix NP100
- Fix E303 by ensuring only one blank line between lines of code

* fix rope ratio to solve incorrect answers

* fix by comments

---------

Signed-off-by: XingXing Qiao <redacted>
Co-authored-by: XingXing Qiao <redacted>
Co-authored-by: Umpire2018 <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
include/llama.h
src/llama.cpp
tests/test-chat-template.cpp

index 455eea883b93b25fd0f114386b49dacc45a3972f..6ee41d3a118e53e46bfad85e0678b64b6ca4279b 100755 (executable)
@@ -487,6 +487,9 @@ class Model:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+            res = "chatglm-bpe"
         if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
             # ref: https://huggingface.co/LumiOpen/Viking-7B
             res = "viking"
@@ -3176,6 +3179,190 @@ class JaisModel(Model):
         self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
 
 
+@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
+class ChatGLMModel(Model):
+    model_arch = gguf.MODEL_ARCH.CHATGLM
+
+    def set_vocab_chatglm3(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+        scores: list[float] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
+        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
+        for token_id in range(vocab_size):
+            piece = tokenizer._convert_id_to_token(token_id)
+            if token_id == 0:
+                piece = "<unk>"
+            elif token_id == 1:
+                piece = "<bos>"
+            elif token_id == 2:
+                piece = "<eos>"
+
+            text = piece.encode("utf-8")
+            score = 0.0
+            # Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
+            # it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size()
+            if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
+                score = tokenizer.tokenizer.sp_model.get_score(token_id)
+
+            if len(piece) == 0:
+                text = f"[PAD{token_id}]".encode("utf-8")
+
+            if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
+                if piece in special_tokens:
+                    # show special tokens in prompt
+                    toktype = SentencePieceTokenTypes.USER_DEFINED
+                else:
+                    toktype = SentencePieceTokenTypes.UNKNOWN
+                tokens.append(text)
+                scores.append(score)
+                toktypes.append(toktype)
+                continue
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.tokenizer.sp_model.is_unknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.tokenizer.sp_model.is_control(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.tokenizer.sp_model.is_unused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.tokenizer.sp_model.is_byte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        # glm3 needs prefix and suffix formatted as:
+        # prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"
+        self.gguf_writer.add_tokenizer_pre("chatglm-spm")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
+            self.set_vocab_chatglm3()
+            return
+
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["padded_vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) >= 2 and len(merged) <= 7
+            merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.get_added_vocab()
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                if tokenizer.added_tokens_decoder[i].special:
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.USER_DEFINED)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_head_kv = self.hparams.get("multi_query_group_num", n_head)
+        self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
+        self.gguf_writer.add_embedding_length(n_embed)
+        self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
+        self.gguf_writer.add_block_count(self.hparams["num_layers"])
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_dimension_count(64)
+        self.gguf_writer.add_add_bos_token(False)
+        rope_freq = 10000
+        if "rope_ratio" in self.hparams:
+            rope_freq = rope_freq * self.hparams["rope_ratio"]
+        self.gguf_writer.add_rope_freq_base(rope_freq)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".rotary_pos_emb.inv_freq"):
+            return []
+
+        name = name.removeprefix("transformer.")
+        return [(self.map_tensor_name(name), data_torch)]
+
 ###### CONVERSION LOGIC ######
 
 
index d78fb1e44b29ee7a72b6d6139370ad2306fcaad6..a95a44237e34826d7e4d2b910ba01b7ff2911c34 100644 (file)
@@ -120,7 +120,6 @@ class Keys:
         MIDDLE_ID            = "tokenizer.ggml.middle_token_id"
         EOT_ID               = "tokenizer.ggml.eot_token_id"
 
-
 #
 # recommended mapping of model tensor names for storage in gguf
 #
@@ -163,6 +162,7 @@ class MODEL_ARCH(IntEnum):
     OPENELM      = auto()
     ARCTIC       = auto()
     DEEPSEEK2    = auto()
+    CHATGLM      = auto()
     BITNET       = auto()
     T5           = auto()
     JAIS         = auto()
@@ -289,6 +289,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.OPENELM:        "openelm",
     MODEL_ARCH.ARCTIC:         "arctic",
     MODEL_ARCH.DEEPSEEK2:      "deepseek2",
+    MODEL_ARCH.CHATGLM:        "chatglm",
     MODEL_ARCH.BITNET:         "bitnet",
     MODEL_ARCH.T5:             "t5",
     MODEL_ARCH.JAIS:           "jais",
@@ -924,6 +925,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.CHATGLM : [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
@@ -1020,6 +1033,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.CHATGLM: [
+        MODEL_TENSOR.ROPE_FREQS,
+    ],
 }
 
 #
index 6054f61595e135e1e0b7a27aabe0ec6806fd28cc..7264240f5e17a2967f4dd86febf4564d5fca7376 100644 (file)
@@ -24,6 +24,7 @@ class TensorNameMap:
             "backbone.embedding",                        # mamba
             "backbone.embeddings",                       # mamba-hf
             "transformer.in_out_embed",                  # Grok
+            "embedding.word_embeddings",                 # chatglm
             "transformer.token_embeddings",              # openelm
             "shared",                                    # t5
         ),
@@ -55,6 +56,7 @@ class TensorNameMap:
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
+            "output_layer",              # chatglm
         ),
 
         # Output norm
@@ -71,12 +73,14 @@ class TensorNameMap:
             "model.norm_f",                            # mamba-qbert
             "backbone.norm_f",                         # mamba
             "transformer.rms_norm",                    # Grok
+            "encoder.final_layernorm",                 # chatglm
             "transformer.norm",                        # openelm
         ),
 
         # Rope frequencies
         MODEL_TENSOR.ROPE_FREQS: (
             "rope.freqs",  # llama-pth
+            "rotary_pos_emb.inv_freq",  # chatglm
         ),
     }
 
@@ -101,6 +105,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.norm",                           # mamba
             "transformer.decoder_layer.{bid}.rms_norm",             # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",       # dbrx
+            "encoder.layers.{bid}.input_layernorm",                 # chatglm
             "transformer.layers.{bid}.attn_norm",                   # openelm
         ),
 
@@ -124,6 +129,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
             "encoder.layers.{bid}.attn.Wqkv",                                      # nomic-bert
             "model.layers.{bid}.self_attn.qkv_proj",                               # phi3
+            "encoder.layers.{bid}.self_attention.query_key_value",                 # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
         ),
 
@@ -135,7 +141,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.q_proj",                           # gpt-j
             "model.layers.layers.{bid}.self_attn.q_proj",                # plamo
             "model.layers.{bid}.attention.wq",                           # internlm2
-            "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
+            "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
         ),
 
         # Attention key
@@ -147,7 +153,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.k",                              # refact
             "model.layers.layers.{bid}.self_attn.k_proj",              # plamo
             "model.layers.{bid}.attention.wk",                         # internlm2
-            "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
+            "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
         ),
 
         # Attention value
@@ -182,6 +188,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
             "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
+            "encoder.layers.{bid}.self_attention.dense",                    # chatglm
             "transformer.layers.{bid}.attn.out_proj",                       # openelm
         ),
 
@@ -218,6 +225,7 @@ class TensorNameMap:
             "h.{bid}.ln_2",                                                  # gpt2
             "model.layers.{bid}.ffn_norm",                                   # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2",                    # Grok
+            "encoder.layers.{bid}.post_attention_layernorm",                 # chatglm
             "transformer.layers.{bid}.ffn_norm",                             # openelm
         ),
 
@@ -268,6 +276,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.c_fc",                            # starcoder2
             "encoder.layer.{bid}.mlp.gated_layers_v",                 # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w3",                     # arctic
+            "encoder.layers.{bid}.mlp.dense_h_to_4h",                 # chatglm
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -337,6 +346,7 @@ class TensorNameMap:
             "transformer.layers.{bid}.ffn.proj_2",                    # openelm
             "model.layers.{bid}.residual_mlp.w2",                     # arctic
             "encoder.layer.{bid}.mlp.down_layer",                     # jina-bert-v2
+            "encoder.layers.{bid}.mlp.dense_4h_to_h",                 # chatglm
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
index 865ace9944d02064a518bfdba4ebb794c127a52f..bb4b05ba636711618429afaa2c65c0cd273649d5 100644 (file)
@@ -88,8 +88,10 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
         LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
-        LLAMA_VOCAB_PRE_TYPE_VIKING         = 16,
-        LLAMA_VOCAB_PRE_TYPE_JAIS           = 17,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+        LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+        LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
     };
 
     // note: these values should be synchronized with ggml_rope
index 1a65faef9bedc5a348fcc9e385ad0281acf39b3e..2b9ace28584572ff60a8c7844c6f06f9d5e4f4ec 100644 (file)
@@ -229,6 +229,7 @@ enum llm_arch {
     LLM_ARCH_OPENELM,
     LLM_ARCH_ARCTIC,
     LLM_ARCH_DEEPSEEK2,
+    LLM_ARCH_CHATGLM,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_JAIS,
@@ -272,6 +273,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_OPENELM,         "openelm"      },
     { LLM_ARCH_ARCTIC,          "arctic"       },
     { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
+    { LLM_ARCH_CHATGLM,         "chatglm"      },
     { LLM_ARCH_BITNET,          "bitnet"       },
     { LLM_ARCH_T5,              "t5"           },
     { LLM_ARCH_JAIS,            "jais"         },
@@ -1205,6 +1207,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_CHATGLM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {
@@ -2087,9 +2104,11 @@ enum e_model {
     MODEL_2_8B,
     MODEL_3B,
     MODEL_4B,
+    MODEL_6B,
     MODEL_6_9B,
     MODEL_7B,
     MODEL_8B,
+    MODEL_9B,
     MODEL_11B,
     MODEL_12B,
     MODEL_13B,
@@ -2115,7 +2134,6 @@ enum e_model {
     MODEL_16x12B,
     MODEL_10B_128x3_66B,
     MODEL_57B_A14B,
-    MODEL_9B,
     MODEL_27B,
 };
 
@@ -4490,9 +4508,11 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_2_8B:          return "2.8B";
         case MODEL_3B:            return "3B";
         case MODEL_4B:            return "4B";
+        case MODEL_6B:            return "6B";
         case MODEL_6_9B:          return "6.9B";
         case MODEL_7B:            return "7B";
         case MODEL_8B:            return "8B";
+        case MODEL_9B:            return "9B";
         case MODEL_11B:           return "11B";
         case MODEL_12B:           return "12B";
         case MODEL_13B:           return "13B";
@@ -4518,7 +4538,6 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_16x12B:        return "16x12B";
         case MODEL_10B_128x3_66B: return "10B+128x3.66B";
         case MODEL_57B_A14B:      return "57B.A14B";
-        case MODEL_9B:            return "9B";
         case MODEL_27B:           return "27B";
         default:                  return "?B";
     }
@@ -5124,6 +5143,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_CHATGLM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 28: model.type = e_model::MODEL_6B; break;
+                    case 40: model.type = e_model::MODEL_9B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_BITNET:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5256,9 +5284,7 @@ static void llm_load_vocab(
             if (merges_keyidx == -1) {
                 throw std::runtime_error("cannot find tokenizer merges in model file\n");
             }
-
             const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
-
             for (int i = 0; i < n_merges; i++) {
                 const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
                 GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
@@ -5401,6 +5427,10 @@ static void llm_load_vocab(
                 tokenizer_pre == "poro-chat") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
                 vocab.tokenizer_clean_spaces = false;
+            } else if (
+                tokenizer_pre == "chatglm-bpe") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
+                vocab.special_bos_id  = -1;
             } else if (
                 tokenizer_pre == "viking") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
@@ -5525,7 +5555,6 @@ static void llm_load_vocab(
                 vocab.special_eot_id    = 107;
             }
         }
-
         try {
             vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
         } catch (const std::exception & e) {
@@ -7433,6 +7462,36 @@ static bool llm_load_tensors(
                         layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
                     }
                 } break;
+            case LLM_ARCH_CHATGLM:
+                {
+                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + (hparams.n_embd_head_k << 2)});
+
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -7657,6 +7716,7 @@ enum llm_ffn_op_type {
     LLM_FFN_GELU,
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
+    LLM_FFN_SWIGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -7861,6 +7921,19 @@ static struct ggml_tensor * llm_build_ffn(
                 cur = ggml_sqr(ctx, cur);
                 cb(cur, "ffn_sqr(relu)", il);
             } break;
+        case LLM_FFN_SWIGLU:
+            {
+                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+                int64_t split_point = cur->ne[0] / 2;
+                struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_silu(ctx, x0);
+                cb(cur, "ffn_silu", il);
+
+                cur = ggml_mul(ctx, x0, x1);
+                cb(cur, "ffn_mul", il);
+            } break;
     }
 
     if (type_gate == LLM_FFN_PAR) {
@@ -10709,19 +10782,12 @@ struct llm_build_context {
             // special-case: the up and gate tensors are merged into a single tensor
             // TOOD: support into llm_build_ffn
             {
-                struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
-                cb(up, "ffn_up", il);
-
-                auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0));
-                auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2));
-
-                y = ggml_mul(ctx0, y, ggml_silu(ctx0, g));
-                cb(y, "ffn_gate", il);
-
-                auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y);
-                cb(down, "ffn_down", il);
-
-                cur = down;
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
@@ -13413,6 +13479,120 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_chatglm() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = nullptr;
+                struct ggml_tensor * Kcur = nullptr;
+                struct ggml_tensor * Vcur = nullptr;
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // Add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -13644,6 +13824,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_deepseek2();
             } break;
+        case LLM_ARCH_CHATGLM:
+            {
+                result = llm.build_chatglm();
+            } break;
         case LLM_ARCH_BITNET:
             {
                 result = llm.build_bitnet();
@@ -15259,6 +15443,11 @@ struct llm_tokenizer_bpe {
                     " ?[^(\\s|.,!?…。,、।۔،)]+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
+                regex_exprs = {
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_VIKING:
                 regex_exprs = {
                     "\\p{N}",
@@ -16160,7 +16349,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                 if (add_special) {
                     tokenizer.append_bos(output);
                 }
-
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -19151,6 +19339,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OLMO:
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_CHATGLM:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
@@ -20883,7 +21072,6 @@ int32_t llama_tokenize(
                         bool   add_special,
                         bool   parse_special) {
     auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
-
     if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
         return -((int) res.size());
@@ -21302,6 +21490,25 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
         }
+    } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
+        // chatglm3-6b
+        ss << "[gMASK]" << "sop";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n " << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
+    } else if (tmpl == "chaglm4" || tmpl_contains("[gMASK]<sop>")) {
+        ss << "[gMASK]" << "<sop>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
     } else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
         // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
         for (auto message : chat) {
index 03f5369109716f07b62599dc1fb165313f085975..6583dd0b2b39b652216bb1f1c5eae58a6cbf698d 100644 (file)
@@ -58,6 +58,10 @@ int main(void) {
         "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
         //Phi-3-vision
         "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
+        // ChatGLM3
+        "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+        // ChatGLM4
+        u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
         // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
         u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
         // DeepSeek-V2
@@ -98,6 +102,10 @@ int main(void) {
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         //Phi-3-vision
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        // ChatGLM3
+        "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n    I am an assistant   <|user|>\n Another question<|assistant|>",
+        // ChatGLM4
+        "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
         // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
         u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
         // DeepSeek-V2