if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
+ if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
+ # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
+ res = "hunyuan"
if res is None:
logger.warning("\n")
super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
+
+@ModelBase.register("HunYuanMoEV1ForCausalLM")
+class HunYuanMoEModel(TextModel):
+ model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # For handling tied embeddings
+ self._tok_embd = None
+
+ def set_vocab(self):
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+ # 1. Get the pre-tokenizer identifier hash
+ tokpre = self.get_vocab_base_pre(tokenizer)
+
+ # 2. Reverse-engineer the merges list from mergeable_ranks
+ merges = []
+ vocab = {}
+ mergeable_ranks = tokenizer.mergeable_ranks
+ for token, rank in mergeable_ranks.items():
+ vocab[QwenModel.token_bytes_to_string(token)] = rank
+ if len(token) == 1:
+ continue
+ merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+ if len(merged) == 2: # todo this is an assert in Qwen, why?
+ merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+ # 3. Generate the tokens and toktypes lists
+ vocab_size = self.hparams["vocab_size"]
+ assert tokenizer.vocab_size == vocab_size
+ special_tokens = tokenizer.special_tokens
+ reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+ tokens: list[str] = []
+ toktypes: list[int] = []
+ for i in range(vocab_size):
+ if i not in reverse_vocab:
+ tokens.append(f"[PAD{i}]")
+ toktypes.append(gguf.TokenType.UNUSED)
+ else:
+ token = reverse_vocab[i]
+ tokens.append(token)
+ if i in special_tokens.values():
+ toktypes.append(gguf.TokenType.CONTROL)
+ else:
+ toktypes.append(gguf.TokenType.NORMAL)
+
+ # 4. Write all vocab-related fields to the GGUF writer
+ self.gguf_writer.add_tokenizer_model("gpt2")
+ self.gguf_writer.add_tokenizer_pre(tokpre)
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+ self.gguf_writer.add_token_merges(merges)
+
+ # 5. Add special tokens and chat templates
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+ special_vocab.add_to_gguf(self.gguf_writer)
+ # FIX for BOS token: Overwrite incorrect id read from config.json
+ self.gguf_writer.add_bos_token_id(127959) # <|bos|>
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+
+ self.gguf_writer.add_expert_count(hparams["num_experts"])
+ self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
+
+ moe_intermediate_size = hparams["moe_intermediate_size"]
+ assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size)
+ self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])
+
+ moe_topk = hparams["moe_topk"]
+ assert all(topk == moe_topk[0] for topk in moe_topk)
+ self.gguf_writer.add_expert_used_count(moe_topk[0])
+
+ moe_shared_expert = hparams["num_shared_expert"]
+ assert all(n == moe_shared_expert[0] for n in moe_shared_expert)
+ self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
+
+ # Rope
+ rope_scaling = hparams.get("rope_scaling", {})
+ if rope_scaling.get("type") == "dynamic":
+ # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+ alpha = rope_scaling.get("alpha", 1000)
+ base = hparams.get("rope_theta", 10000.0)
+ dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128
+ scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
+ self.gguf_writer.add_rope_freq_base(scaled_base)
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+ self.gguf_writer.add_rope_scaling_factor(1)
+ # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+ self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
+ self.gguf_writer.add_context_length(256 * 1024) # 256k context length
+
+ # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+ assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \
+ "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+ _experts: list[dict[str, Tensor]] | None = None
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ if name == "model.embed_tokens.weight":
+ self._tok_embd = data_torch.clone()
+
+ if name == "lm_head.weight":
+ if self.hparams.get("tie_word_embeddings", False):
+ logger.info("Skipping tied output layer 'lm_head.weight'")
+ return []
+
+ if name.find("mlp.experts") != -1:
+ n_experts = self.hparams["num_experts"]
+ assert bid is not None
+
+ if self._experts is None:
+ self._experts = [{} for _ in range(self.block_count)]
+
+ self._experts[bid][name] = data_torch
+
+ if len(self._experts[bid]) >= n_experts * 3:
+ # merge the experts into a single 3d tensor
+ tensors: list[tuple[str, Tensor]] = []
+ for w_name in ["down_proj", "gate_proj", "up_proj"]:
+ datas: list[Tensor] = []
+
+ for xid in range(n_experts):
+ ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+ datas.append(self._experts[bid][ename])
+ del self._experts[bid][ename]
+
+ data_torch = torch.stack(datas, dim=0)
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+ new_name = self.map_tensor_name(merged_name)
+ tensors.append((new_name, data_torch))
+
+ return tensors
+ else:
+ return []
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def prepare_tensors(self):
+ super().prepare_tensors()
+ if self._experts is not None:
+ experts = [k for d in self._experts for k in d.keys()]
+ if len(experts) > 0:
+ raise ValueError(f"Unprocessed experts: {experts}")
+
###### CONVERSION LOGIC ######
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
+ {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
]
DOTS1 = auto()
ARCEE = auto()
ERNIE4_5 = auto()
+ HUNYUAN_MOE = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
+ MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.HUNYUAN_MOE: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ MODEL_TENSOR.FFN_GATE_SHEXP,
+ MODEL_TENSOR.FFN_DOWN_SHEXP,
+ MODEL_TENSOR.FFN_UP_SHEXP,
+ ],
# TODO
}
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
+ "model.layers.{bid}.mlp.gate.wg", # hunyuan
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
+ "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
),
# AWQ-activation gate
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4
+ "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
),
# Feed-forward down
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4
"model.layers.{bid}.shared_mlp.output_linear", # granitemoe
+ "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
),
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
+ "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
+ "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
+ LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
};
enum llama_rope_type {
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
+ { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_HUNYUAN_MOE,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
{
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5,
+ LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_UNKNOWN,
};
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
+ { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) {
return LLM_CHAT_TEMPLATE_DOTS1;
+ } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
+ return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
if (add_ass) {
ss << "<|response|>";
}
+ } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
+ // tencent/Hunyuan-A13B-Instruct
+ for (auto message : chat) {
+ std::string role(message->role);
+ if (role == "system") {
+ ss << "<|startoftext|>" << message->content << "<|extra_4|>";
+ } else if (role == "assistant") {
+ ss << "<|startoftext|>" << message->content << "<|eos|>";
+ } else {
+ ss << "<|startoftext|>" << message->content << "<|extra_0|>";
+ }
+ }
+ if (add_ass) {
+ ss << "<|startoftext|>";
+ }
} else {
// template not supported
return -1;
LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
+ LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
case LLM_TYPE_57B_A14B: return "57B.A14B";
case LLM_TYPE_17B_16E: return "17Bx16E (Scout)";
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
+ case LLM_TYPE_A13B: return "A13B";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_E2B: return "E2B";
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_HUNYUAN_MOE:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
+
+ switch (hparams.n_layer) {
+ case 32: type = LLM_TYPE_A13B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
default: throw std::runtime_error("unsupported model architecture");
}
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
+ case LLM_ARCH_HUNYUAN_MOE:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
}
};
+struct llm_build_hunyuan_moe : public llm_graph_context {
+ llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn = build_attn_inp_kv_unified();
+
+ const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = build_norm(Kcur,
+ model.layers[il].attn_k_norm, nullptr,
+ LLM_NORM_RMS, il);
+ cb(Kcur, "Kcur_norm", il);
+
+ Qcur = build_norm(Qcur,
+ model.layers[il].attn_q_norm, nullptr,
+ LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_norm", il);
+
+ cur = build_attn(inp_attn, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+ cb(cur, "attn_out", il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network (non-MoE)
+ ggml_tensor * cur_mlp = build_ffn(cur,
+ model.layers[il].ffn_up_shexp, NULL, NULL,
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
+ model.layers[il].ffn_down_shexp, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(cur_mlp, "ffn_mlp", il);
+
+ // MoE branch
+ ggml_tensor * cur_moe = build_moe_ffn(cur,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ model.layers[il].ffn_gate_exps,
+ model.layers[il].ffn_down_exps,
+ nullptr,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU,
+ true, // norm_topk_prob
+ false,
+ 0.0,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+ il);
+ cb(cur_moe, "ffn_moe_out", il);
+
+ ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp);
+ cb(ffn_out, "ffn_out", il);
+
+ cur = ggml_add(ctx0, ffn_out, ffn_inp);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;
{
llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
} break;
+ case LLM_ARCH_HUNYUAN_MOE:
+ {
+ llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
+ } break;
default:
GGML_ABORT("fatal error");
}
case LLM_ARCH_EXAONE:
case LLM_ARCH_MINICPM3:
case LLM_ARCH_DOTS1:
+ case LLM_ARCH_HUNYUAN_MOE:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
LLM_TYPE_57B_A14B,
LLM_TYPE_17B_16E, // llama4 Scout
LLM_TYPE_17B_128E, // llama4 Maverick
+ LLM_TYPE_A13B,
LLM_TYPE_30B_A3B,
LLM_TYPE_235B_A22B,
LLM_TYPE_E2B,
break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+ case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
regex_exprs = {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
tokenizer_pre == "seed-coder") {
pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
clean_spaces = false;
+ } else if (
+ tokenizer_pre == "hunyuan") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
+ clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}