]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: support GLM 4.5 family of models (#14939)
authorSam <redacted>
Mon, 4 Aug 2025 18:29:25 +0000 (04:29 +1000)
committerGitHub <redacted>
Mon, 4 Aug 2025 18:29:25 +0000 (20:29 +0200)
* model: Add GLM 4.5 (#14921)

Co-authored-by: Sigbjørn Skjæret <redacted>
* Merge in PR suggestions

Co-authored-by: Sigbjørn Skjæret <redacted>
* model: Add GLM 4.5 family of models (#14921)

1. Updated tensor_mapping.py with NextN tensor mappings

- Added proper tensor mappings for all NextN/MTP tensors in /Users/samm/git/llama.cpp/gguf-py/gguf/tensor_mapping.py
- Added mappings for: eh_proj, embed_tokens, enorm, hnorm, shared_head.head, shared_head.norm

2. Added num_nextn_predict_layers configuration

- Added LLM_KV_NUM_NEXTN_PREDICT_LAYERS constant to llama-arch.h and llama-arch.cpp
- Added num_nextn_predict_layers field to llama_hparams struct
- Updated GLM4_MOE parameter loading in llama-model.cpp to read this parameter
- Modified tensor loading logic to conditionally load NextN tensors based on num_nextn_predict_layers
- Added GGUF writer support in gguf_writer.py with add_num_nextn_predict_layers() method
- Updated conversion script to extract and write this parameter from HuggingFace config

3. Added FIM tokens for GLM4_MOE

- Added GLM-4.5's FIM tokens to llama-vocab.cpp:
  - <|code_prefix|> for FIM_PRE
  - <|code_suffix|> for FIM_SUF
  - <|code_middle|> for FIM_MID

4. Removed manual NextN tensor handling

- Removed the special-case handling in convert_hf_to_gguf.py that manually mapped NextN tensors
- NextN tensors are now handled automatically through the proper tensor mapping system

* glm 4.5 update tensors names

* model: glm 4.5 apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* model: glm 4.5 apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <redacted>
* model: glm 4.5 apply suggestions from code review

* Apply suggestions from code review

* patch broken chat template

* typings fix

* add TENSOR_SKIP flag

Co-authored-by: Diego Devesa <redacted>
* Update src/llama-model-loader.h

Co-authored-by: Sigbjørn Skjæret <redacted>
---------

Co-authored-by: Sigbjørn Skjæret <redacted>
Co-authored-by: Diego Devesa <redacted>
15 files changed:
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
models/templates/README.md
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-hparams.h
src/llama-kv-cache-unified.cpp
src/llama-model-loader.h
src/llama-model.cpp
src/llama-model.h
src/llama-vocab.cpp

index 9303a047694f566052f4eae26c7732a33ee3b0d8..a215f4ed7294a6632b4649538aa5c5f822f6a460 100755 (executable)
@@ -678,6 +678,9 @@ class TextModel(ModelBase):
         if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
             # ref: https://huggingface.co/THUDM/glm-4-9b-hf
             res = "glm4"
+        if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
+            # ref: https://huggingface.co/zai-org/GLM-4.5-Air
+            res = "glm4"
         if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
             # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
             res = "minerva-7b"
@@ -6696,6 +6699,139 @@ class Glm4Model(TextModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Glm4MoeForCausalLM")
+class Glm4MoeModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GLM4_MOE
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Special tokens
+        # Note: Using <|endoftext|> (151329) for eot causes endless generation
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"])  # 151331
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])  # 151336
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
+        special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"])  # 151338
+
+        # Patch broken chat template
+        if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template:
+            special_vocab.chat_template = special_vocab.chat_template.replace(
+                """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""",
+                """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""")
+
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = (
+                self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+            )
+        self.gguf_writer.add_rope_dimension_count(
+            int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
+        )
+
+        # MoE parameters - Use only routed expert count (shared experts handled separately)
+        if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_routed_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+        if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None:
+            self.gguf_writer.add_expert_shared_count(n_shared_experts)
+        if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None:
+            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+
+        # Expert gating function (sigmoid for GLM4_MOE)
+        self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+
+        # Routed scaling factor
+        if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None:
+            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
+
+        # Normalise topk probabilities
+        if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None:
+            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+
+        # NextN/MTP prediction layers
+        if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):  # ignore visual part
+            return []
+        elif name.startswith("model.language_model."):
+            name = name.replace("language_model.", "")  # for multimodal variants
+
+        # Handle main token embedding (but not layer-specific NextN embeddings)
+        if name == "model.embed_tokens.weight" and ".layers." not in name:
+            return [(self.map_tensor_name("token_embd.weight"), data_torch)]
+
+        # Handle routed experts
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        new_name = self.map_tensor_name(name)
+
+        return [(new_name, data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
 class ChatGLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHATGLM
index 226805f1e1ff82863774929e327f05fc73664836..575e05e193c2ea0a5cb6712cf429cda5ebab76c6 100755 (executable)
@@ -147,6 +147,7 @@ pre_computed_hashes = [
     {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"},
     {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
+    {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
     {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
     {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
index 5707085cb6687a7172c829a95ac1a0edaf99b62a..e2d81dd9891b53c467efdab0f1def3f78af7efe1 100644 (file)
@@ -105,6 +105,7 @@ class Keys:
         EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
         MOE_EVERY_N_LAYERS                = "{arch}.moe_every_n_layers"
+        NEXTN_PREDICT_LAYERS              = "{arch}.nextn_predict_layers"
         POOLING_TYPE                      = "{arch}.pooling_type"
         LOGIT_SCALE                       = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
@@ -357,6 +358,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK2        = auto()
     CHATGLM          = auto()
     GLM4             = auto()
+    GLM4_MOE         = auto()
     BITNET           = auto()
     T5               = auto()
     T5ENCODER        = auto()
@@ -614,6 +616,13 @@ class MODEL_TENSOR(IntEnum):
     A_MMPROJ_FC          = auto()
     A_MM_NORM_PRE        = auto()
     A_MM_NORM_MID        = auto()
+    # nextn/mtp
+    NEXTN_EH_PROJ        = auto()
+    NEXTN_EMBED_TOKENS   = auto()
+    NEXTN_ENORM          = auto()
+    NEXTN_HNORM          = auto()
+    NEXTN_SHARED_HEAD_HEAD = auto()
+    NEXTN_SHARED_HEAD_NORM = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -678,6 +687,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.DEEPSEEK2:        "deepseek2",
     MODEL_ARCH.CHATGLM:          "chatglm",
     MODEL_ARCH.GLM4:             "glm4",
+    MODEL_ARCH.GLM4_MOE:         "glm4moe",
     MODEL_ARCH.BITNET:           "bitnet",
     MODEL_ARCH.T5:               "t5",
     MODEL_ARCH.T5ENCODER:        "t5encoder",
@@ -936,6 +946,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.A_MMPROJ_FC:               "mm.a.fc",
     MODEL_TENSOR.A_MM_NORM_PRE:             "mm.a.norm_pre",
     MODEL_TENSOR.A_MM_NORM_MID:             "mm.a.norm_mid",
+    # NextN/MTP
+    MODEL_TENSOR.NEXTN_EH_PROJ:             "blk.{bid}.nextn.eh_proj",
+    MODEL_TENSOR.NEXTN_EMBED_TOKENS:        "blk.{bid}.nextn.embed_tokens",
+    MODEL_TENSOR.NEXTN_ENORM:               "blk.{bid}.nextn.enorm",
+    MODEL_TENSOR.NEXTN_HNORM:               "blk.{bid}.nextn.hnorm",
+    MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD:    "blk.{bid}.nextn.shared_head_head",
+    MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM:    "blk.{bid}.nextn.shared_head_norm",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -2124,6 +2141,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_POST_NORM,
         MODEL_TENSOR.FFN_POST_NORM,
     ],
+    MODEL_ARCH.GLM4_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
index f4fd64ad822fae7ab85137b5a06d7c715750b976..89249021bc5723d6fce4602202c5593698eb2f93 100644 (file)
@@ -753,6 +753,9 @@ class GGUFWriter:
     def add_moe_every_n_layers(self, value: int) -> None:
         self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
 
+    def add_nextn_predict_layers(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
 
index e6efc93fad82d7da9b42cebd41b84718b61b6e10..dd4f3d52e45ffd9321325ce66bc2d2b064194f46 100644 (file)
@@ -1369,6 +1369,31 @@ class TensorNameMap:
         MODEL_TENSOR.A_MM_NORM_MID: (
             "audio.multi_modal_projector.ln_mid", # ultravox
         ),
+
+        # NextN/MTP tensors for GLM4_MOE
+        MODEL_TENSOR.NEXTN_EH_PROJ: (
+            "model.layers.{bid}.eh_proj",
+        ),
+
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS: (
+            "model.layers.{bid}.embed_tokens",
+        ),
+
+        MODEL_TENSOR.NEXTN_ENORM: (
+            "model.layers.{bid}.enorm",
+        ),
+
+        MODEL_TENSOR.NEXTN_HNORM: (
+            "model.layers.{bid}.hnorm",
+        ),
+
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: (
+            "model.layers.{bid}.shared_head.head",
+        ),
+
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: (
+            "model.layers.{bid}.shared_head.norm",
+        ),
     }
 
     # architecture-specific block mappings
index 35b6386dd0649c2a245d4b9aebfffccbf954abe9..2e8eaa5953b864299850476bd8c785700f6dffd6 100644 (file)
@@ -21,4 +21,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct                      > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
 ./scripts/get_chat_template.py Qwen/QwQ-32B                                  > models/templates/Qwen-QwQ-32B.jinja
 ./scripts/get_chat_template.py Qwen/Qwen3-0.6B                               > models/templates/Qwen-Qwen3-0.6B.jinja
-```
\ No newline at end of file
+./scripts/get_chat_template.py zai-org/GLM-4.5                               > models/templates/zai-org-GLM-4.5.jinja
+```
index ba7bf9598670f5e585aa52985d393845700766d6..8d669bddcaf528a23c5aa80cfab6bd39e2e2dbc5 100644 (file)
@@ -62,6 +62,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK2,        "deepseek2"        },
     { LLM_ARCH_CHATGLM,          "chatglm"          },
     { LLM_ARCH_GLM4,             "glm4"             },
+    { LLM_ARCH_GLM4_MOE,         "glm4moe"          },
     { LLM_ARCH_BITNET,           "bitnet"           },
     { LLM_ARCH_T5,               "t5"               },
     { LLM_ARCH_T5ENCODER,        "t5encoder"        },
@@ -127,6 +128,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
     { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
     { LLM_KV_MOE_EVERY_N_LAYERS,                "%s.moe_every_n_layers"                },
+    { LLM_KV_NEXTN_PREDICT_LAYERS,              "%s.nextn_predict_layers"              },
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
@@ -1391,6 +1393,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GLM4_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM,     "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
+            // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number)
+            { LLM_TENSOR_NEXTN_EH_PROJ,      "blk.%d.nextn.eh_proj" },
+            { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
+            { LLM_TENSOR_NEXTN_ENORM,        "blk.%d.nextn.enorm" },
+            { LLM_TENSOR_NEXTN_HNORM,        "blk.%d.nextn.hnorm" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
+            { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
+        },
+    },
     {
         LLM_ARCH_BITNET,
         {
@@ -2181,6 +2217,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_SHORTCONV_CONV,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SHORTCONV_INPROJ,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_SHORTCONV_OUTPROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    // NextN/MTP tensors are currently ignored (reserved for future MTP support)
+    // These tensors only exist in the last layer(s) and are treated as output tensors
+    {LLM_TENSOR_NEXTN_EH_PROJ,              {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_EMBED_TOKENS,         {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_ENORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_NEXTN_HNORM,                {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,     {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
 };
 
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
index 9b8bd65b2322fdc675190fbb0bbe9894f9270425..456eb8d8ccda2122fc992ca9ee94507155a9f6b4 100644 (file)
@@ -66,6 +66,7 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_CHATGLM,
     LLM_ARCH_GLM4,
+    LLM_ARCH_GLM4_MOE,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
@@ -131,6 +132,7 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
     LLM_KV_MOE_EVERY_N_LAYERS,
+    LLM_KV_NEXTN_PREDICT_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
@@ -409,6 +411,12 @@ enum llm_tensor {
     LLM_TENSOR_SHORTCONV_CONV,
     LLM_TENSOR_SHORTCONV_INPROJ,
     LLM_TENSOR_SHORTCONV_OUTPROJ,
+    LLM_TENSOR_NEXTN_EH_PROJ,
+    LLM_TENSOR_NEXTN_EMBED_TOKENS,
+    LLM_TENSOR_NEXTN_ENORM,
+    LLM_TENSOR_NEXTN_HNORM,
+    LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
+    LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
 };
 
 enum llm_tensor_layer {
index 491a26b6346deafcb004fa6648fd8e7ac1ecb51a..9c15e83242849a2e8453907a5a113e7f6c381293 100644 (file)
@@ -749,8 +749,8 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     if (down) {
         cur = build_lora_mm(down, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
@@ -1391,8 +1391,8 @@ ggml_tensor * llm_graph_context::build_attn(
 
     if (wo) {
         cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
             ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
         }
     }
index 8b7e2a1130755ac418a326e1c8a6ab5f16b1f7a7..d60035726802e2efef0e9b4655c21e36f54f88fe 100644 (file)
@@ -73,6 +73,7 @@ struct llama_hparams {
     bool     expert_weights_norm  = false;
     uint32_t expert_gating_func   = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
     uint32_t moe_every_n_layers   = 0;
+    uint32_t nextn_predict_layers = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
index e1614d1b8ea5f83bb64bf4e4ae951f0c3d4d48e5..e539142e6b8cd139b313043660ee70a8f1eb1bd9 100644 (file)
@@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     if (model.arch == LLM_ARCH_GEMMA3N) {
         n_layer_cache = 20;
     }
+    if (model.arch == LLM_ARCH_GLM4_MOE) {
+        // GLM-4.5: Only process up to last layer, skip final NextN layer
+        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
+    }
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
index 0f52b011b698624e560c767d4ad5d6cd3140343c..c9189f6cb4466421beabc4fb44df9106ecd6c6b8 100644 (file)
@@ -58,8 +58,9 @@ struct llama_model_loader {
         }
     };
 
-    static const int TENSOR_NOT_REQUIRED = 1;
-    static const int TENSOR_DUPLICATED   = 2;
+    static const int TENSOR_NOT_REQUIRED = 1 << 0;
+    static const int TENSOR_DUPLICATED   = 1 << 1;
+    static const int TENSOR_SKIP         = 1 << 2;
 
     int n_kv      = 0;
     int n_tensors = 0;
index 60a615c159a51c6a96ec3baa2deeefcf50d89a3b..44f89003b3917ce2d1eb2d0c615ee836b01e9a9d 100644 (file)
@@ -109,8 +109,10 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_A13B:          return "A13B";
         case LLM_TYPE_21B_A3B:       return "21B.A3B";
         case LLM_TYPE_30B_A3B:       return "30B.A3B";
+        case LLM_TYPE_106B_A12B:     return "106B.A12B";
         case LLM_TYPE_235B_A22B:     return "235B.A22B";
         case LLM_TYPE_300B_A47B:     return "300B.A47B";
+        case LLM_TYPE_355B_A32B:     return "355B.A32B";
         case LLM_TYPE_E2B:           return "E2B";
         case LLM_TYPE_E4B:           return "E4B";
         default:                     return "?B";
@@ -1434,6 +1436,34 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GLM4_MOE:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // MoE parameters
+                ml.get_key(LLM_KV_EXPERT_COUNT,                hparams.n_expert);
+                ml.get_key(LLM_KV_EXPERT_USED_COUNT,           hparams.n_expert_used);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead, false);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
+
+                // Expert gating function (GLM-4.5 uses sigmoid)
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
+                if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
+                    hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
+                }
+
+                // NextN/MTP parameters
+                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,        hparams.nextn_predict_layers, false);
+
+                switch (hparams.n_layer) {
+                    case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
+                    case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_BITNET:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1949,6 +1979,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
     const auto TENSOR_DUPLICATED   = llama_model_loader::TENSOR_DUPLICATED;
     const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED;
+    const auto TENSOR_SKIP         = llama_model_loader::TENSOR_SKIP;
 
     // create tensors for the weights
     {
@@ -2004,7 +2035,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             }
 
             // skip unused tensors
-            if (info.op == GGML_OP_NONE) {
+            if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) {
                 const size_t nbytes = ggml_nbytes(t_meta);
                 LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes);
 
@@ -4427,6 +4458,105 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_post_norm  = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_GLM4_MOE:
+                {
+                    const int64_t n_expert        = hparams.n_expert;
+                    const int64_t n_expert_used   = hparams.n_expert_used;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers");
+                    GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers");
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+                    }
+
+                    // Load ALL tensors including NextN layer to satisfy total tensor count
+                    // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
+                    for (int i = 0; i < n_layer; ++i) {
+                        int flags = 0;
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            // skip all tensors in the NextN layers
+                            flags |= TENSOR_SKIP;
+                        }
+
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags);
+
+                        // GLM-style attention with bias terms
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
+
+                        // K/Q norm tensors (optional for GLM-4.5 355B variant)
+                        layer.attn_q_norm = create_tensor(
+                            tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags);
+                        layer.attn_k_norm = create_tensor(
+                            tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags);
+
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags);
+
+                        // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead
+                        // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE
+                        const bool use_moe = (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead);
+
+                        if (use_moe) {
+                            // MoE layers
+                            layer.ffn_gate_inp =
+                                create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags);
+
+                            // MoE branch
+                            const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                            layer.ffn_gate_exps = create_tensor(
+                                tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
+                            layer.ffn_down_exps = create_tensor(
+                                tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags);
+                            layer.ffn_up_exps = create_tensor(
+                                tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags);
+
+                            // Shared expert
+                            if (n_expert_shared > 0) {
+                                const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
+                                layer.ffn_gate_shexp = create_tensor(
+                                    tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
+                                layer.ffn_down_shexp = create_tensor(
+                                    tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags);
+                                layer.ffn_up_shexp = create_tensor(
+                                    tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags);
+                            }
+                        } else {
+                            // Dense layers (first k layers) - GLM uses separate gate/up projections
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), { n_embd, n_ff }, flags);
+                        }
+
+                        // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            layer.nextn.eh_proj          = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
+                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags);
+                            layer.nextn.enorm            = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
+                            layer.nextn.hnorm            = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags);
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags);
+                        }
+                    }
+                }
+                break;
             case LLM_ARCH_NEMOTRON:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -13564,6 +13694,169 @@ struct llm_build_glm4 : public llm_graph_context {
     }
 };
 
+struct llm_build_glm4_moe : public llm_graph_context {
+    llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        // Only process up to last layer (skip final NextN layer)
+        // Final layer tensors are loaded but not processed in forward pass
+        const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+        for (int il = 0; il < n_transformer_layers; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // Pre-attention norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                }
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                }
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                }
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                // Apply Q/K norm if available (GLM-4.5 355B variant)
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                    cb(Qcur, "Qcur_normed", il);
+                }
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                    cb(Kcur, "Kcur_normed", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_transformer_layers - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // Post-attention norm
+            cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "post_attn_norm", il);
+
+            // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
+            if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
+                // Dense FFN layer
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE layer with shared experts
+                const int64_t n_expert      = hparams.n_expert;
+                const int64_t n_expert_used = hparams.n_expert_used;
+
+                // Process routed experts using existing MoE infrastructure
+                ggml_tensor * routed_out = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        model.layers[il].ffn_exp_probs_b,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, hparams.expert_weights_norm,
+                        true, hparams.expert_weights_scale,
+                        (llama_expert_gating_func_type) hparams.expert_gating_func,
+                        il);
+                cb(routed_out, "ffn_moe_out", il);
+
+                // Process shared expert on original input
+                ggml_tensor * shared_out = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(shared_out, "ffn_shexp_out", il);
+
+                // Final output: routed_output + shared_output
+                cur = ggml_add(ctx0, routed_out, shared_out);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_nemotron : public llm_graph_context {
     llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -17877,6 +18170,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_glm4>(*this, params);
             } break;
+        case LLM_ARCH_GLM4_MOE:
+            {
+                llm = std::make_unique<llm_build_glm4_moe>(*this, params);
+            } break;
         case LLM_ARCH_BITNET:
             {
                 llm = std::make_unique<llm_build_bitnet>(*this, params);
@@ -18208,6 +18505,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_HUNYUAN_DENSE:
         case LLM_ARCH_LFM2:
         case LLM_ARCH_SMALLTHINKER:
+        case LLM_ARCH_GLM4_MOE:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index 094e23808a81392674c1122369cdef9c3651dc57..bdb81cecd89d03c5c14cb1ef1452027213fcc27b 100644 (file)
@@ -101,8 +101,10 @@ enum llm_type {
     LLM_TYPE_A13B,
     LLM_TYPE_21B_A3B, // Ernie MoE small
     LLM_TYPE_30B_A3B,
+    LLM_TYPE_106B_A12B, // GLM-4.5-Air
     LLM_TYPE_235B_A22B,
     LLM_TYPE_300B_A47B, // Ernie MoE big
+    LLM_TYPE_355B_A32B, // GLM-4.5
     LLM_TYPE_E2B,
     LLM_TYPE_E4B,
 };
@@ -166,6 +168,15 @@ struct llama_layer_shortconv {
     struct ggml_tensor * out_proj = nullptr;
 };
 
+struct llama_layer_nextn {
+    struct ggml_tensor * eh_proj          = nullptr;
+    struct ggml_tensor * embed_tokens     = nullptr;
+    struct ggml_tensor * enorm            = nullptr;
+    struct ggml_tensor * hnorm            = nullptr;
+    struct ggml_tensor * shared_head_head = nullptr;
+    struct ggml_tensor * shared_head_norm = nullptr;
+};
+
 struct llama_layer {
     // normalization
     struct ggml_tensor * attn_norm       = nullptr;
@@ -354,6 +365,8 @@ struct llama_layer {
     struct llama_layer_convnext convnext;
 
     struct llama_layer_shortconv shortconv;
+
+    struct llama_layer_nextn nextn;
 };
 
 struct llama_model {
index 959c86a14745fe5c02873a74a9d488379e3053f4..3f43fc556f857efb34b0490ba8d0d49cd7b511f7 100644 (file)
@@ -2191,6 +2191,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁begin|>" // DeepSeek
                         || t.first == "<PRE>"
                         || t.first == "▁<PRE>"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM-4.5
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2210,6 +2211,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == "<SUF>"
                         || t.first == "▁<SUF>"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM-4.5
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2229,6 +2231,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == "<MID>"
                         || t.first == "▁<MID>"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM-4.5
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {