]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : improve model arch handling (#13122)
authorXuan-Son Nguyen <redacted>
Wed, 30 Apr 2025 14:56:24 +0000 (16:56 +0200)
committerGitHub <redacted>
Wed, 30 Apr 2025 14:56:24 +0000 (16:56 +0200)
* convert : improve model arch handling

* use AutoConfig

* rm trust_remote_code

* Update convert_hf_to_gguf.py

* fix self.block_count for vision

* fix NomicBertModel

convert_hf_to_gguf.py

index d607af6955987b9500ff24f5c02ae0cdc6cc003d..123df801bf095b67e768e1ca1311e939c475425b 100755 (executable)
@@ -16,6 +16,7 @@ from pathlib import Path
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig
 
 import math
 import numpy as np
@@ -66,8 +67,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +116,6 @@ class ModelBase:
             if not self.is_safetensors:
                 self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -417,15 +418,13 @@ class ModelBase:
 
     @staticmethod
     def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-            architectures = hparams.get("architectures")
-            if "text_config" in hparams:
-                hparams = {**hparams, **hparams["text_config"]}
-            if architectures is not None:
-                # preserve "architectures" from root level config
-                hparams["architectures"] = architectures
-            return hparams
+        try:
+            return AutoConfig.from_pretrained(dir_model).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                return json.load(f)
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -454,6 +453,23 @@ class ModelBase:
 
 
 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
     def set_vocab(self):
         self._set_vocab_gpt2()
 
@@ -1070,9 +1086,9 @@ class VisionModel(ModelBase):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
-        # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
         if "vision_config" not in self.hparams:
@@ -1081,6 +1097,9 @@ class VisionModel(ModelBase):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]
 
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
@@ -1098,7 +1117,7 @@ class VisionModel(ModelBase):
         self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
         self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_block_count(self.block_count)
         self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
 
         # preprocessor config
@@ -1719,23 +1738,12 @@ class StableLMModel(TextModel):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
-                and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1898,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = 12 # see tokenizer_config.json
         else:
@@ -1913,7 +1917,6 @@ class LlavaVisionModel(VisionModel):
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
             self.gguf_writer.add_vision_use_silu(True)
 
@@ -1944,13 +1947,12 @@ class LlavaVisionModel(VisionModel):
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
         if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
             self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
             self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3505,6 +3507,8 @@ class RobertaModel(BertModel):
 
 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
@@ -5849,6 +5853,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()
 
@@ -5901,16 +5918,15 @@ def main() -> None:
 
     logger.info(f"Loading model: {dir_model.name}")
 
-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
 
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError: