]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : converting mmproj for Qwen2/2.5VL from convert_hf_to_gguf (#13209)
authorXuan-Son Nguyen <redacted>
Fri, 2 May 2025 15:17:15 +0000 (17:17 +0200)
committerGitHub <redacted>
Fri, 2 May 2025 15:17:15 +0000 (17:17 +0200)
* wip

* qwen2.5vl ok

* vision: fix models missing "text_config"

* add test

* fix test repo name

* fix 32B model

* Revert "fix 32B model"

This reverts commit 651752f1ae25fe8a01c1e57c18cf2eca80b2774e.

* clarify about 32B

* rm qwen surgery script

* update llava/readme

* move V_ENC_EMBD_PATCH handling to Qwen2VLVisionModel

convert_hf_to_gguf.py
examples/llava/README.md
examples/llava/qwen2_vl_surgery.py [deleted file]
examples/llava/tests.sh
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py

index df3f8a55d5320cffd95a31cc71f65d66b9480eb6..ff82a85a9d7cd464fabbead5e4084fb58f1fa699 100755 (executable)
@@ -1089,6 +1089,8 @@ class VisionModel(ModelBase):
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
         # get n_embd of the text model
+        if "text_config" not in self.hparams:
+            self.hparams["text_config"] = {}
         text_config = {**self.hparams, **self.hparams["text_config"]}
         self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
@@ -2583,6 +2585,82 @@ class Qwen2VLModel(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+class Qwen2VLVisionModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        # rename config.json values
+        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
+        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
+        if "embed_dim" in self.hparams: # qwen2vl
+            self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
+            self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if self.global_config['model_type'] == 'qwen2_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
+        elif self.global_config['model_type'] == 'qwen2_5_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
+            self.gguf_writer.add_vision_use_silu(True)
+            # find n_wa_pattern (window attention pattern)
+            fullatt_block_indexes = hparams.get("fullatt_block_indexes")
+            assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl"
+            n_wa_pattern = fullatt_block_indexes[0] + 1
+            # validate n_wa_pattern
+            for i in range(1, len(fullatt_block_indexes)):
+                if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
+                    raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+        else:
+            raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
+        # default values below are taken from HF tranformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims  # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("visual."):
+            # process visual tensors
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2: # weight
+                    c3, _ = data_torch.shape
+                else: # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
+                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
+                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
+                ]
+            elif 'patch_embed.proj.weight' in name:
+                # split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = data_torch.shape
+                del c1, c2, kh, kw  # unused
+                assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
+                return [
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight"  , data_torch[:, :, 0, ...]),
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+                ]
+            else:
+                return [(self.map_tensor_name(name), data_torch)]
+        return [] # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
index 3b62627ce829fec6d88fa7e65be643463b8f396e..b97b9e8c54367e39834bd0201fe51a88558ba0d3 100644 (file)
@@ -35,6 +35,16 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
 # Pixtral 12B
 llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
 
+# Qwen 2 VL
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF
+
+# Qwen 2.5 VL
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF
+
 # Mistral Small 3.1 24B (IQ2_M quantization)
 llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
 ```
@@ -60,7 +70,17 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advanta
 
 ## How to obtain `mmproj`
 
-Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
+Multimodal projector (`mmproj`) files are specific to each model architecture.
+
+For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
+- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+
+For older models, please refer to the relevant guide for instructions on how to obtain or create them:
 
 - [LLaVA](../../docs/multimodal/llava.md)
 - [MobileVLM](../../docs/multimodal/MobileVLM.md)
@@ -70,10 +90,3 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. P
 - [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
 - [IBM Granite Vision](../../docs/multimodal/granitevision.md)
 - [Google Gemma 3](../../docs/multimodal/gemma3.md)
-
-For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
-- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
-- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
-- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py
deleted file mode 100644 (file)
index 7951a6f..0000000
+++ /dev/null
@@ -1,217 +0,0 @@
-import argparse
-from typing import Dict, List, Optional
-
-import torch
-import numpy as np
-from gguf import *
-from transformers import (
-    AutoProcessor,
-    Qwen2VLConfig,
-    Qwen2VLProcessor,
-    Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLConfig, # type: ignore[reportAttributeAccessIssue]
-    Qwen2_5_VLForConditionalGeneration, # type: ignore[reportAttributeAccessIssue]
-)
-
-
-VISION = "clip.vision"
-
-
-def k(raw_key: str, arch: str) -> str:
-    return raw_key.format(arch=arch)
-
-
-def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
-    if fullatt_block_indexes is None:
-        return 0
-    n_wa = fullatt_block_indexes[0]
-    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
-        if b - a - 1 != n_wa:
-            raise ValueError(
-                f"window/full attention layer should have fix pattern of "
-                f"for each full-attention layer followed by {n_wa} window-attention layers"
-            )
-    return n_wa + 1
-
-
-class VL2:
-
-    @staticmethod
-    def to_gguf_name(name: str) -> str:
-        og = name
-        name = name.replace("text_model", "t").replace("vision_model", "v")
-        name = name.replace("blocks", "blk").replace("embeddings.", "")
-        name = name.replace("attn.", "attn_")
-        name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
-        # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
-        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-        name = name.replace("merger.mlp", 'mm')
-        print(f"[to_gguf_name] {og} --> {name}")
-        return name
-
-    @classmethod
-    def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
-        vision_model = qwen2vl.visual
-        tensor_map = {}
-        for name, ten in vision_model.state_dict().items():
-            ten = ten.numpy()
-            if 'qkv' in name:
-                if ten.ndim == 2: # weight
-                    c3, _ = ten.shape
-                else:             # bias
-                    c3 = ten.shape[0]
-                assert c3 % 3 == 0
-                c = c3 // 3
-                wq = ten[:c]
-                wk = ten[c: c * 2]
-                wv = ten[c * 2:]
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
-            elif 'merger' in name:
-                if name.endswith("ln_q.weight"):
-                    tensor_map['v.post_ln.weight'] = ten
-                elif name.endswith("ln_q.bias"):
-                    tensor_map['v.post_ln.bias'] = ten
-                else:
-                    # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
-                    tensor_map[cls.to_gguf_name(name)] = ten
-            elif 'patch_embed.proj.weight' in name:
-                # NOTE: split Conv3D into Conv2Ds
-                c1, c2, kt, kh, kw = ten.shape
-                assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
-                tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
-                tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
-            else:
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
-
-        for new_name, ten in tensor_map.items():
-            if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
-                tensor_map[new_name] = ten.astype(np.float32)
-            else:
-                tensor_map[new_name] = ten.astype(dtype)
-        tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
-        return tensor_map
-
-
-class VL25(VL2):
-
-    @staticmethod
-    def to_gguf_name(name: str) -> str:
-        og = name
-        name = name.replace("text_model", "t").replace("vision_model", "v")
-        name = name.replace("blocks", "blk").replace("embeddings.", "")
-        name = name.replace("attn.", "attn_")
-        name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
-        name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
-        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-        name = name.replace("merger.mlp", 'mm')
-        print(f"[vl25][to_gguf_name] {og} --> {name}")
-        return name
-
-
-def main(args):
-    if args.data_type == 'fp32':
-        dtype = torch.float32
-        np_dtype = np.float32
-        ftype = 0
-    elif args.data_type == 'fp16':
-        dtype = torch.float16
-        np_dtype = np.float16
-        ftype = 1
-    else:
-        raise ValueError()
-
-    local_model = False
-    model_path = ""
-    model_name = args.model_name
-    print("model_name: ", model_name)
-    if args.model_type == "qwen2vl":
-        qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="cpu"
-        )
-        cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-        vcfg = cfg.vision_config
-    else:
-        qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="cpu"
-        )
-        cfg: Qwen2_5_VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-        vcfg = cfg.vision_config
-
-    if os.path.isdir(model_name):
-        local_model = True
-        if model_name.endswith(os.sep):
-            model_name = model_name[:-1]
-        model_path = model_name
-        model_name = os.path.basename(model_name)
-    fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"
-
-    fout = GGUFWriter(path=fname_out, arch="clip")
-    fout.add_description("image encoder for Qwen2VL")
-
-    fout.add_file_type(ftype)
-    fout.add_bool("clip.has_text_encoder", False)
-    fout.add_bool("clip.has_vision_encoder", True)
-    fout.add_bool("clip.has_qwen2vl_merger", True)
-
-    print(cfg.vision_config)
-    if 'silu' in cfg.vision_config.hidden_act.lower():
-        fout.add_bool("clip.use_silu", True)
-        fout.add_bool("clip.use_gelu", False)
-    elif 'gelu' in cfg.vision_config.hidden_act.lower():
-        fout.add_bool("clip.use_silu", False)
-        fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
-    else:
-        raise ValueError()
-
-    if args.model_type == "qwen2.5vl":
-        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
-        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
-        fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
-        fout.add_string("clip.projector_type", "qwen2.5vl_merger")
-    else:
-        fout.add_string("clip.projector_type", "qwen2vl_merger")
-        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
-        fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
-
-    if args.model_type == "qwen2.5vl":
-        tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
-    else:
-        tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
-    for name, data in tensor_map.items():
-        fout.add_tensor(name, data)
-
-    fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
-    fout.add_uint32("clip.vision.image_size", 14 * 40)  # some reasonable size that is divable by (14*2)
-    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
-    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
-    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0)  # not sure what this does, put 0 here as a placeholder
-    fout.add_name(model_name)
-    """
-    HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
-            it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
-    """
-
-    if local_model:
-        processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path)
-    else:
-        processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
-    fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
-    fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
-
-    fout.write_header_to_file()
-    fout.write_kv_data_to_file()
-    fout.write_tensors_to_file()
-    fout.close()
-    print("save model as: ", fname_out)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
-    parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
-    parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
-    args = parser.parse_args()
-    main(args)
index 4af370064086f94ce8bc390a1ef3f9d5c210b19f..22c23749713316d14e13fd206ce4d60183935740 100755 (executable)
@@ -36,12 +36,6 @@ add_test() {
     arr_tmpl+=("$tmpl")
 }
 
-add_test_big() {
-    if [ "$RUN_BIG_TESTS" = true ]; then
-        add_test "$@"
-    fi
-}
-
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
@@ -58,8 +52,16 @@ add_test "llama-mtmd-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
-add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
-add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+if [ "$RUN_BIG_TESTS" = true ]; then
+    add_test "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
+fi
 
 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
index a2540bd93fd91ea9fd88f1e1850e38c6fdc958aa..74e46c3ee0f9523f00609efb8ba91b546896677a 100644 (file)
@@ -234,6 +234,7 @@ class Keys:
         SPATIAL_MERGE_SIZE  = "clip.vision.spatial_merge_size"
         USE_GELU            = "clip.use_gelu"
         USE_SILU            = "clip.use_silu"
+        N_WA_PATTERN        = "clip.vision.n_wa_pattern" # used by qwen2.5vl
 
         class Attention:
             HEAD_COUNT      = "clip.vision.attention.head_count"
@@ -2162,6 +2163,8 @@ class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
     PIXTRAL = "pixtral"
+    QWEN2VL = "qwen2vl_merger"
+    QWEN25VL = "qwen2.5vl_merger"
 
 
 # Items here are (block size, type size)
index a30c49e32b35180a2b9df82b7761370f4dac86b1..ff50d3de31287bf6aa27a609a3e9846e8a5f5336 100644 (file)
@@ -984,6 +984,9 @@ class GGUFWriter:
     def add_vision_projector_scale_factor(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
 
+    def add_vision_n_wa_pattern(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
index 2f6326104ffa70f5aa1e5e7e3933921683f272a0..2b089f84a841ad9d064db3f86b4c1ba4205e4187 100644 (file)
@@ -896,6 +896,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
+            "visual.merger.mlp.{bid}", # qwen2vl
         ),
 
         MODEL_TENSOR.V_MMPROJ_FC: (
@@ -919,6 +920,7 @@ class TensorNameMap:
             "vpm.embeddings.patch_embedding",
             "model.vision_model.embeddings.patch_embedding", # SmolVLM
             "vision_tower.patch_conv", # pixtral
+            "visual.patch_embed.proj", # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -932,6 +934,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
+            "visual.blocks.{bid}.attn.q", # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K: (
@@ -939,6 +942,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.k_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
+            "visual.blocks.{bid}.attn.k", # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_V: (
@@ -946,6 +950,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.v_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
+            "visual.blocks.{bid}.attn.v", # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -953,6 +958,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
+            "visual.blocks.{bid}.norm1", # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT: (
@@ -960,6 +966,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
+            "visual.blocks.{bid}.attn.proj", # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
@@ -967,17 +974,24 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.layer_norm2",
             "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
+            "visual.blocks.{bid}.norm2", # qwen2vl
         ),
 
+        # some namings are messed up because the original llava code swapped fc1 and fc2
+        # we have no better way to fix it, just be careful
+        # new models like pixtral use the correct naming
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
             "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
             "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
+            "visual.blocks.{bid}.mlp.fc2", # qwen2vl
+            "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_ENC_FFN_GATE: (
             "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
+            "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
@@ -985,6 +999,8 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.mlp.fc2",
             "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
             "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
+            "visual.blocks.{bid}.mlp.fc1", # qwen2vl
+            "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_PRE_NORM: (
@@ -995,6 +1011,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_POST_NORM: (
             "vision_tower.vision_model.post_layernorm",
             "model.vision_model.post_layernorm", # SmolVLM
+            "visual.merger.ln_q", # qwen2vl
         ),
 
         MODEL_TENSOR.V_MM_INP_PROJ: (