From: Xuan-Son Nguyen
Date: Thu, 15 May 2025 15:40:07 +0000 (+0200)
Subject: convert : fix conversion for llama 4 (#13567)
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=c531edfa34fec8074d0b424396cdd450df23b612;p=pkg%2Fggml%2Fsources%2Fllama.cpp

convert : fix conversion for llama 4 (#13567)
---

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 68b5e879..b5402cbe 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2069,6 +2069,9 @@ class Llama4Model(LlamaModel):
         self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.startswith("language_model."):
+            name = name.replace("language_model.", "")
+
         # split the gate_up into gate and up
         if "gate_up_proj" in name:
             name_up = name.replace("gate_up_proj", "up_proj.weight")
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 2629b3c1..dadb4fd9 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -68,7 +68,7 @@ class TensorNameMap:
             "output_layer",              # chatglm
             "head",                      # rwkv
             "head.out",                  # wavtokenizer
-            "language_model.lm_head",    # llama4
+            "lm_head",                   # llama4
         ),
 
         # Output norm
@@ -91,7 +91,7 @@ class TensorNameMap:
             "rwkv.ln_out",                 # rwkv6
             "model.ln_out",                # rwkv7
             "backbone.final_layer_norm",   # wavtokenizer
-            "language_model.model.norm",   # llama4
+            "model.norm",                  # llama4
         ),
 
         # Rope frequencies
@@ -133,7 +133,7 @@ class TensorNameMap:
             "transformer.layers.{bid}.attn_norm",                 # openelm
             "rwkv.blocks.{bid}.ln1",                              # rwkv6
             "model.layers.{bid}.ln1",                             # rwkv7
-            "language_model.model.layers.{bid}.input_layernorm",  # llama4
+            "model.layers.{bid}.input_layernorm",                 # llama4
         ),
 
         # Attention norm 2
@@ -173,7 +173,7 @@ class TensorNameMap:
             "model.layers.{bid}.attention.wq",                            # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
             "transformer.h.{bid}.attn.attention.q_proj",                  # exaone
-            "language_model.model.layers.{bid}.self_attn.q_proj",         # llama4
+            "model.layers.{bid}.self_attn.q_proj",                        # llama4
         ),
 
         # Attention key
@@ -188,7 +188,7 @@ class TensorNameMap:
             "model.layers.{bid}.attention.wk",                          # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
             "transformer.h.{bid}.attn.attention.k_proj",                # exaone
-            "language_model.model.layers.{bid}.self_attn.k_proj",       # llama4
+            "model.layers.{bid}.self_attn.k_proj",                      # llama4
         ),
 
         # Attention value
@@ -202,7 +202,7 @@ class TensorNameMap:
             "model.layers.{bid}.attention.wv",                            # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
             "transformer.h.{bid}.attn.attention.v_proj",                  # exaone
-            "language_model.model.layers.{bid}.self_attn.v_proj",         # llama4
+            "model.layers.{bid}.self_attn.v_proj",                        # llama4
         ),
 
         # Attention output
@@ -229,7 +229,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.self_attention.dense",           # chatglm
             "transformer.layers.{bid}.attn.out_proj",              # openelm
             "transformer.h.{bid}.attn.attention.out_proj",         # exaone
-            "language_model.model.layers.{bid}.self_attn.o_proj",  # llama4
+            "model.layers.{bid}.self_attn.o_proj",                 # llama4
         ),
 
         # Attention output norm
@@ -268,7 +268,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.rms_norm_2",                  # Grok
             "encoder.layers.{bid}.post_attention_layernorm",               # chatglm
             "transformer.layers.{bid}.ffn_norm",                           # openelm
-            "language_model.model.layers.{bid}.post_attention_layernorm",  # llama4
+            "model.layers.{bid}.post_attention_layernorm",                 # llama4
         ),
 
         # Post feed-forward norm
@@ -289,7 +289,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.router",                 # Grok
             "transformer.blocks.{bid}.ffn.router.layer",              # dbrx
             "model.layers.{bid}.block_sparse_moe.router.layer",       # granitemoe
-            "language_model.model.layers.{bid}.feed_forward.router",  # llama4
+            "model.layers.{bid}.feed_forward.router",                 # llama4
             "encoder.layers.{bid}.mlp.router.layer",                  # nomic-bert-moe
         ),
 
@@ -329,7 +329,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w3",                      # arctic
             "encoder.layers.{bid}.mlp.dense_h_to_4h",                  # chatglm
             "transformer.h.{bid}.mlp.c_fc_1",                          # exaone
-            "language_model.model.layers.{bid}.feed_forward.up_proj",  # llama4
+            "model.layers.{bid}.feed_forward.up_proj",                 # llama4
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -338,14 +338,14 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.experts.mlp.v1",                     # dbrx
             "model.layers.{bid}.mlp.experts.up_proj",                          # qwen2moe olmoe (merged)
             "model.layers.{bid}.block_sparse_moe.experts.w3",                  # phimoe (merged)
-            "language_model.model.layers.{bid}.feed_forward.experts.up_proj",  # llama4
+            "model.layers.{bid}.feed_forward.experts.up_proj",                 # llama4
             "encoder.layers.{bid}.mlp.experts.mlp.w1",                         # nomic-bert-moe
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
-            "model.layers.{bid}.mlp.shared_expert.up_proj",                          # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.up_proj",                         # deepseek deepseek2
-            "language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj",  # llama4
+            "model.layers.{bid}.mlp.shared_expert.up_proj",           # qwen2moe
+            "model.layers.{bid}.mlp.shared_experts.up_proj",          # deepseek deepseek2
+            "model.layers.{bid}.feed_forward.shared_expert.up_proj",  # llama4
         ),
 
         # AWQ-activation gate
@@ -366,22 +366,22 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.linear_1",                          # refact
             "model.layers.{bid}.residual_mlp.w1",                        # arctic
             "transformer.h.{bid}.mlp.c_fc_0",                            # exaone
-            "language_model.model.layers.{bid}.feed_forward.gate_proj",  # llama4
+            "model.layers.{bid}.feed_forward.gate_proj",                 # llama4
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.w1",                              # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear",                        # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w1",                       # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj",                          # qwen2moe olmoe (merged)
-            "model.layers.{bid}.block_sparse_moe.experts.w1",                    # phimoe (merged)
-            "language_model.model.layers.{bid}.feed_forward.experts.gate_proj",  # llama4
+            "layers.{bid}.feed_forward.experts.w1",               # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear",         # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w1",        # dbrx
+            "model.layers.{bid}.mlp.experts.gate_proj",           # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.experts.w1",     # phimoe (merged)
+            "model.layers.{bid}.feed_forward.experts.gate_proj",  # llama4
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
-            "model.layers.{bid}.mlp.shared_expert.gate_proj",                          # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.gate_proj",                         # deepseek deepseek2
-            "language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj",  # llama4
+            "model.layers.{bid}.mlp.shared_expert.gate_proj",           # qwen2moe
+            "model.layers.{bid}.mlp.shared_experts.gate_proj",          # deepseek deepseek2
+            "model.layers.{bid}.feed_forward.shared_expert.gate_proj",  # llama4
         ),
 
         # Feed-forward down
@@ -410,7 +410,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.mlp.down_layer",                        # jina-bert-v2
             "encoder.layers.{bid}.mlp.dense_4h_to_h",                    # chatglm
             "model.layers.h.{bid}.mlp.c_proj",                           # exaone
-            "language_model.model.layers.{bid}.feed_forward.down_proj",  # llama4
+            "model.layers.{bid}.feed_forward.down_proj",                 # llama4
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -420,15 +420,15 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.experts.down_proj",                          # qwen2moe olmoe (merged)
             "model.layers.{bid}.block_sparse_moe.output_linear",                 # granitemoe
             "model.layers.{bid}.block_sparse_moe.experts.w2",                    # phimoe (merged)
-            "language_model.model.layers.{bid}.feed_forward.experts.down_proj",  # llama4
+            "model.layers.{bid}.feed_forward.experts.down_proj",                 # llama4
             "encoder.layers.{bid}.mlp.experts.mlp.w2",                           # nomic-bert-moe
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
-            "model.layers.{bid}.mlp.shared_expert.down_proj",                           # qwen2moe
-            "model.layers.{bid}.mlp.shared_experts.down_proj",                          # deepseek deepseek2
-            "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj",   # llama4
-            "model.layers.{bid}.shared_mlp.output_linear",                              # granitemoe
+            "model.layers.{bid}.mlp.shared_expert.down_proj",           # qwen2moe
+            "model.layers.{bid}.mlp.shared_experts.down_proj",          # deepseek deepseek2
+            "model.layers.{bid}.feed_forward.shared_expert.down_proj",  # llama4
+            "model.layers.{bid}.shared_mlp.output_linear",              # granitemoe
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
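
Note: the two files move together. A minimal standalone sketch of the idea follows; it is not the converter code itself, the helper normalize_llama4_name is introduced only for illustration, and the tensor names in examples are hypothetical. It mirrors the prefix strip added to Llama4Model.modify_tensors, which is why the "# llama4" entries in tensor_mapping.py can drop their "language_model." prefix.

# Minimal sketch, not part of the patch: the Llama 4 HF checkpoint wraps the
# text model in a "language_model." namespace; stripping it lets the
# un-prefixed "model.layers.{bid}..." entries in tensor_mapping.py match.

def normalize_llama4_name(name: str) -> str:
    # Same logic as the lines added to modify_tensors in the diff above.
    if name.startswith("language_model."):
        name = name.replace("language_model.", "")
    return name

if __name__ == "__main__":
    # Hypothetical tensor names, used only for illustration.
    examples = [
        "language_model.model.layers.0.self_attn.q_proj.weight",
        "language_model.model.norm.weight",
        "language_model.lm_head.weight",
    ]
    for raw in examples:
        print(f"{raw} -> {normalize_llama4_name(raw)}")
        # e.g. "language_model.model.layers.0.self_attn.q_proj.weight" becomes
        # "model.layers.0.self_attn.q_proj.weight", which matches the
        # "model.layers.{bid}.self_attn.q_proj" llama4 entry above.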