if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
+from gguf.vocab import MistralTokenizerType, MistralVocab
+from mistral_common.tokens.tokenizers.base import TokenizerVersion
+from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
+from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+from mistral_common.tokens.tokenizers.sentencepiece import (
+ SentencePieceTokenizer,
+)
+
logger = logging.getLogger("hf-to-gguf")
block_count: int
tensor_map: gguf.TensorNameMap
+ is_mistral_format: bool = False
+
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
self.tensor_names = set(name for name in remote_tensors.keys())
- for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
+ for name, remote_tensor in remote_tensors.items():
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
self.get_tensors = get_remote_tensors
else:
- self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
+ prefix = "model" if not self.is_mistral_format else "consolidated"
+ self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
self.is_safetensors = len(self.part_names) > 0
if not self.is_safetensors:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
- self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
+ self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.tensor_names = None
self.metadata_override = metadata_override
self.model_name = model_name
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
tensor_names_from_parts: set[str] = set()
- index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
- index_name += ".index.json"
- index_file = self.dir_model / index_name
-
- if index_file.is_file():
- self.tensor_names = set()
- logger.info(f"gguf: loading model weight map from '{index_name}'")
- with open(index_file, "r", encoding="utf-8") as f:
- index: dict[str, Any] = json.load(f)
- weight_map = index.get("weight_map")
- if weight_map is None or not isinstance(weight_map, dict):
- raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
- self.tensor_names.update(weight_map.keys())
+ if not self.is_mistral_format:
+ index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+ index_name += ".index.json"
+ index_file = self.dir_model / index_name
+
+ if index_file.is_file():
+ self.tensor_names = set()
+ logger.info(f"gguf: loading model weight map from '{index_name}'")
+ with open(index_file, "r", encoding="utf-8") as f:
+ index: dict[str, Any] = json.load(f)
+ weight_map = index.get("weight_map")
+ if weight_map is None or not isinstance(weight_map, dict):
+ raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
+ self.tensor_names.update(weight_map.keys())
+ else:
+ self.tensor_names = tensor_names_from_parts
+ weight_map = {}
else:
self.tensor_names = tensor_names_from_parts
weight_map = {}
return part_names
@staticmethod
- def load_hparams(dir_model: Path):
+ def load_hparams(dir_model: Path, is_mistral_format: bool):
+ if is_mistral_format:
+ with open(dir_model / "params.json", "r", encoding="utf-8") as f:
+ config = json.load(f)
+ return config
+
try:
# for security reason, we don't allow loading remote code by default
# if a model need remote code, we will fallback to config.json
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.hf_arch = get_model_architecture(self.hparams, self.model_type)
+ if not self.is_mistral_format:
+ self.hf_arch = get_model_architecture(self.hparams, self.model_type)
+ else:
+ self.hf_arch = ""
if "text_config" in self.hparams:
# move the text_config to the root level
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")
- if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+ if (n_head_kv := self.find_hparam(["num_key_value_heads", "n_kv_heads"], optional=True)) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}")
if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
logger.info(f"gguf: rope theta = {rope_theta}")
- if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+ if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"])) is not None:
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
# get n_embd of the text model
- if "text_config" not in self.hparams:
- self.hparams["text_config"] = {}
- if "audio_config" not in self.hparams:
- self.hparams["audio_config"] = {}
- text_config = {**self.hparams, **self.hparams["text_config"]}
- self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
+ if not self.is_mistral_format:
+ if "text_config" not in self.hparams:
+ self.hparams["text_config"] = {}
+ if "audio_config" not in self.hparams:
+ self.hparams["audio_config"] = {}
+ text_config = {**self.hparams, **self.hparams["text_config"]}
+ self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
+ else:
+ text_config = {
+ k: v for k, v in self.hparams.items() if k not in ["vision_encoder", "audio_encoder"]
+ }
+ self.n_embd_text = text_config.get("hidden_dim", 0)
+
assert self.n_embd_text > 0, "n_embd not found in hparams"
# move vision config to the top level, while preserving the original hparams in global_config
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config
- with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
- self.preprocessor_config = json.load(f)
+ if not self.is_mistral_format:
+ with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+ self.preprocessor_config = json.load(f)
def get_vision_config(self) -> dict[str, Any] | None:
- return self.global_config.get("vision_config")
+ config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
+ return self.global_config.get(config_name)
def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config.get("audio_config")
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
# preprocessor config
- self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
- self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
+ image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
+ image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
+
+ self.gguf_writer.add_vision_image_mean(image_mean)
+ self.gguf_writer.add_vision_image_std(image_std)
if self.has_audio_encoder:
self.gguf_writer.add_clip_has_audio_encoder(True)
if self.hf_arch == "VLlama3ForCausalLM":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
- def set_vocab(self):
- path_tekken_json = self.dir_model / "tekken.json"
- path_tokenizer_json = self.dir_model / "tokenizer.json"
- if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
- return self.set_vocab_tekken()
-
- try:
- self._set_vocab_sentencepiece()
- except FileNotFoundError:
- try:
- self._set_vocab_llama_hf()
- except (FileNotFoundError, TypeError):
- # Llama 3
- self._set_vocab_gpt2()
-
- # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256)
- if self.hparams.get("vocab_size", 32000) == 32016:
- special_vocab = gguf.SpecialVocab(
- self.dir_model, load_merges=False,
- special_token_types = ['prefix', 'suffix', 'middle', 'eot']
- )
- special_vocab._set_special_token("prefix", 32007)
- special_vocab._set_special_token("suffix", 32008)
- special_vocab._set_special_token("middle", 32009)
- special_vocab._set_special_token("eot", 32010)
- special_vocab.add_to_gguf(self.gguf_writer)
-
- tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
- if tokenizer_config_file.is_file():
- with open(tokenizer_config_file, "r", encoding="utf-8") as f:
- tokenizer_config_json = json.load(f)
- if "add_prefix_space" in tokenizer_config_json:
- self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
-
- # Apply to granite small models only
- if self.hparams.get("vocab_size", 32000) == 49152:
- self.gguf_writer.add_add_bos_token(False)
+ def _set_vocab_mistral(self):
+ vocab = MistralVocab(self.dir_model)
+ logger.info(
+ f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
+ )
- def set_vocab_tekken(self):
- vocab = gguf.vocab.MistralVocab(self.dir_model)
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
tokens = []
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
)
- if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
+ if vocab.tokenizer_type == MistralTokenizerType.tekken:
self.gguf_writer.add_tokenizer_pre("tekken")
self.gguf_writer.add_token_merges(
vocab.extract_vocab_merges_from_model()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(False)
- script_dir = Path(__file__).parent
- template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
- with open(template_path, "r", encoding="utf-8") as f:
- template = f.read()
- self.gguf_writer.add_chat_template(template)
+ template_dir = Path(__file__).parent / "models/templates/"
+
+ template = MistralModel.get_community_chat_template(vocab, template_dir)
+ self.gguf_writer.add_chat_template(template)
+
+ def set_vocab(self):
+ if self.is_mistral_format:
+ return self._set_vocab_mistral()
+
+ path_tekken_json = self.dir_model / "tekken.json"
+ path_tokenizer_json = self.dir_model / "tokenizer.json"
+ if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
+ self._set_vocab_mistral()
+
+ try:
+ self._set_vocab_sentencepiece()
+ except FileNotFoundError:
+ try:
+ self._set_vocab_llama_hf()
+ except (FileNotFoundError, TypeError):
+ # Llama 3
+ self._set_vocab_gpt2()
+
+ # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256)
+ if self.hparams.get("vocab_size", 32000) == 32016:
+ special_vocab = gguf.SpecialVocab(
+ self.dir_model, load_merges=False,
+ special_token_types = ['prefix', 'suffix', 'middle', 'eot']
+ )
+ special_vocab._set_special_token("prefix", 32007)
+ special_vocab._set_special_token("suffix", 32008)
+ special_vocab._set_special_token("middle", 32009)
+ special_vocab._set_special_token("eot", 32010)
+ special_vocab.add_to_gguf(self.gguf_writer)
+
+ tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+ if tokenizer_config_file.is_file():
+ with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+ tokenizer_config_json = json.load(f)
+ if "add_prefix_space" in tokenizer_config_json:
+ self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+ # Apply to granite small models only
+ if self.hparams.get("vocab_size", 32000) == 49152:
+ self.gguf_writer.add_add_bos_token(False)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
- self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+ if not self.is_mistral_format:
+ self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
- n_head = self.hparams["num_attention_heads"]
- n_kv_head = self.hparams.get("num_key_value_heads")
+ n_head = self.find_hparam(["n_heads", "num_attention_heads"])
+ n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"])
+
+ vision_prefixes = [
+ "vision_encoder.",
+ "vision_language_adapter.",
+ "patch_merger.",
+ "pre_mm_projector_norm",
+ ]
+
is_multimodal_tensor = "vision_tower" in name \
or "vision_model" in name \
or "audio_tower" in name \
or "model.connector" in name \
- or "multi_modal_projector" in name
+ or "multi_modal_projector" in name \
+ or any(
+ name.startswith(prefix)
+ for prefix in vision_prefixes
+ )
if is_multimodal_tensor:
return [] # skip vision tensors
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- if self.hparams["model_type"] == "pixtral":
+ if self.hparams.get("model_type") == "pixtral":
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
- logger.info(f"Image break token id: {self.img_break_tok_id}")
+ elif self.is_mistral_format:
+ # hparams is already vision config here so norm_eps is only defined in global_config.
+ self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
+ assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
+ self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+ logger.info(f"Image break token id: {self.img_break_tok_id}")
def get_token_id(self, token: str) -> int:
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
- if hparams["model_type"] == "pixtral":
+ if hparams.get("model_type") == "pixtral":
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
- n_head = self.hparams["num_attention_heads"]
+ n_head = (
+ self.hparams["num_attention_heads"] if not self.is_mistral_format else self.find_vparam(["num_attention_heads"])
+ )
n_kv_head = n_head
- if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
+ valid_prefixes = (
+ "multi_modal_projector.",
+ "vision_tower.",
+ "vision_encoder.",
+ "vision_language_adapter.",
+ "patch_merger.",
+ "pre_mm_projector_norm",
+ )
+
+ if any(name.startswith(prefix) for prefix in valid_prefixes):
# process vision tensors
- if name.endswith(("q_proj.weight", "q_proj.bias")):
+ if name.endswith(("q_proj.weight", "q_proj.bias")) and not self.is_mistral_format:
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
- if name.endswith(("k_proj.weight", "k_proj.bias")):
+ if name.endswith(("k_proj.weight", "k_proj.bias")) and not self.is_mistral_format:
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
return [(self.map_tensor_name(name), data_torch)]
- if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
+ embed_key = "embed_tokens.weight" if not self.is_mistral_format else "tok_embeddings.weight"
+ if self.img_break_tok_id > 0 and embed_key in name:
logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
# for pixtral model, we need to extract the [IMG_BREAK] token embedding
img_break_embd = data_torch[self.img_break_tok_id]
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
- hparams = ModelBase.load_hparams(dir_model)
+ hparams = ModelBase.load_hparams(dir_model, False)
self.is_moe = bool(hparams.get("moe_every_n_layers"))
self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
+
+class MistralModel(LlamaModel):
+ model_arch = gguf.MODEL_ARCH.LLAMA
+ model_name = "Mistral"
+ hf_arch = ""
+ is_mistral_format = True
+ undo_permute = False
+
+ @staticmethod
+ def get_community_chat_template(vocab: MistralVocab, templates_dir: Path):
+ assert TokenizerVersion is not None, "mistral_common is not installed"
+ assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
+ f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
+ )
+
+ if vocab.tokenizer.version == TokenizerVersion.v1:
+ return "mistral-v1"
+ elif vocab.tokenizer.version == TokenizerVersion.v3 and vocab.tokenizer_type == MistralTokenizerType.spm:
+ return "mistral-v3"
+ elif vocab.tokenizer.version == TokenizerVersion.v3 and vocab.tokenizer_type == MistralTokenizerType.tekken:
+ return "mistral-v3-tekken"
+ elif vocab.tokenizer.version == TokenizerVersion.v7 and vocab.tokenizer_type == MistralTokenizerType.spm:
+ return "mistral-v7"
+ elif vocab.tokenizer.version == TokenizerVersion.v7 and vocab.tokenizer_type == MistralTokenizerType.tekken:
+ return "mistral-v7-tekken"
+ elif vocab.tokenizer.version == TokenizerVersion.v11:
+ template_file = "Mistral-Small-3.2-24B-Instruct-2506.jinja"
+ elif vocab.tokenizer.version == TokenizerVersion.v13:
+ template_file = "unsloth-mistral-Devstral-Small-2507.jinja"
+ else:
+ raise ValueError(f"Unknown tokenizer type: {vocab.tokenizer_type} and version {vocab.tokenizer.version}")
+
+ template_path = templates_dir / template_file
+ if not template_path.exists():
+ raise FileNotFoundError(f"Template file not found: {template_path}")
+
+ with open(template_path, "r", encoding="utf-8") as f:
+ template = f.read()
+
+ return template
+
+
+class PixtralModel(LlavaVisionModel):
+ model_name = "Pixtral"
+ hf_arch = ""
+ is_mistral_format = True
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
+
+ self.gguf_writer.add_vision_attention_layernorm_eps(
+ self.find_hparam(["norm_eps"])
+ )
+ self.gguf_writer.add_rope_freq_base(self.find_vparam(["rope_theta"]))
+
+ self.gguf_writer.add_vision_use_silu(True)
+
+ # spatial_merge_size
+ if self.find_vparam(["mm_projector_id"]) == "patch_merge":
+ self.gguf_writer.add_vision_spatial_merge_size(
+ self.find_vparam(["spatial_merge_size"])
+ )
+
+ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+ if name == "vision_language_adapter.w_in.weight":
+ return "mm.1.weight"
+ elif name == "vision_language_adapter.w_out.weight":
+ return "mm.2.weight"
+ return super().map_tensor_name(name, try_suffixes)
+
###### CONVERSION LOGIC ######
"--mmproj", action="store_true",
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
)
+ parser.add_argument(
+ "--mistral-format", action="store_true",
+ help="Whether the model is stored following the Mistral format.",
+ )
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
if "mmproj" not in fname_out.name:
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
+ is_mistral_format = args.mistral_format
+
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
- hparams = ModelBase.load_hparams(dir_model)
- model_architecture = get_model_architecture(hparams, model_type)
- logger.info(f"Model architecture: {model_architecture}")
- try:
- model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
- except NotImplementedError:
- logger.error(f"Model {model_architecture} is not supported")
- sys.exit(1)
+ hparams = ModelBase.load_hparams(dir_model, is_mistral_format)
+ if not is_mistral_format:
+ model_architecture = get_model_architecture(hparams, model_type)
+ logger.info(f"Model architecture: {model_architecture}")
+ try:
+ model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
+ except NotImplementedError:
+ logger.error(f"Model {model_architecture} is not supported")
+ sys.exit(1)
+ elif args.mmproj:
+ assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
+ model_class = PixtralModel
+ else:
+ model_class = MistralModel
model_instance = model_class(dir_model, output_type, fname_out,
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
- remote_hf_model_id=hf_repo_id)
+ remote_hf_model_id=hf_repo_id,
+ )
if args.vocab_only:
logger.info("Exporting model vocab...")
"model.vision_tower.embeddings.patch_embeddings.projection", # Intern-S1
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
- "vision_tower.patch_conv", # pixtral
+ "vision_tower.patch_conv", # pixtral-hf
+ "vision_encoder.patch_conv", # pixtral
"vision_model.patch_embedding.linear", # llama 4
"visual.patch_embed.proj", # qwen2vl
),
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
- "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
+ "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
),
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
- "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
+ "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
),
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
- "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
+ "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
),
"model.vision_tower.encoder.layer.{bid}.layernorm_before", # Intern-S1
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
- "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
+ "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
"vision_model.model.layers.{bid}.input_layernorm", # llama4
"visual.blocks.{bid}.norm1", # qwen2vl
),
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
"vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
- "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
+ "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
"visual.blocks.{bid}.attn.proj", # qwen2vl
),
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
"vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
- "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
+ "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
"visual.blocks.{bid}.norm2", # qwen2vl
),
"model.vision_tower.encoder.layer.{bid}.mlp.fc1", # Intern-S1
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
- "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
+ "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.feed_forward.w3", # pixtral
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
),
MODEL_TENSOR.V_ENC_FFN_GATE: (
- "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
+ "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral
"visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
),
"model.vision_tower.encoder.layer.{bid}.mlp.fc2", # Intern-S1
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
- "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
+ "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral-hf
+ "vision_encoder.transformer.layers.{bid}.feed_forward.w2", # pixtral
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
- "vision_tower.ln_pre", # pixtral
+ "vision_tower.ln_pre", # pixtral-hf
+ "vision_encoder.ln_pre", # pixtral
"vision_model.layernorm_pre", # llama4
),
MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
+ "pre_mm_projector_norm",
),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
),
MODEL_TENSOR.V_MM_PATCH_MERGER: (
- "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
+ "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 - hf
+ "patch_merger.merging_layer", # mistral
),
# audio (mtmd)