from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from itertools import chain
+from transformers import AutoConfig
import math
import numpy as np
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
- block_count: int
- tensor_map: gguf.TensorNameMap
tensor_names: set[str] | None
gguf_writer: gguf.GGUFWriter
model_name: str | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
+ # subclasses should initialize this!
+ block_count: int
+ tensor_map: gguf.TensorNameMap
+
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
if not self.is_safetensors:
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
- self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
- self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
self.metadata_override = metadata_override
self.model_name = model_name
@staticmethod
def load_hparams(dir_model: Path):
- with open(dir_model / "config.json", "r", encoding="utf-8") as f:
- hparams = json.load(f)
- architectures = hparams.get("architectures")
- if "text_config" in hparams:
- hparams = {**hparams, **hparams["text_config"]}
- if architectures is not None:
- # preserve "architectures" from root level config
- hparams["architectures"] = architectures
- return hparams
+ try:
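+ # prefer AutoConfig: its to_dict() includes the defaults defined by the model's
+ # transformers config class, not just the keys present in config.json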
+ return AutoConfig.from_pretrained(dir_model).to_dict()
+ except Exception as e:
+ logger.warning(f"Failed to load model config from {dir_model}: {e}")
+ logger.warning("Trying to load config.json instead")
+ with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+ return json.load(f)
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
class TextModel(ModelBase):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ if "text_config" in self.hparams:
+ # merge the text_config keys into the root level
+ self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
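+ # resolve the layer count and tensor name mapping only after text_config has been merged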
+ self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+ self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+ @classmethod
+ def __init_subclass__(cls):
+ # can't use an abstract property, because overriding it without type errors
+ # would require using decorated functions instead of simply defining the property
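+ # enforce that every concrete subclass declares its own model_arch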
+ if "model_arch" not in cls.__dict__:
+ raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
def set_vocab(self):
self._set_vocab_gpt2()
if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
- # small hack to correct the number of layers
- self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
- self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+ # get n_embd of the text model
+ text_config = {**self.hparams, **self.hparams.get("text_config", {})}
+ self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
assert self.n_embd_text > 0, "n_embd not found in hparams"
if "vision_config" not in self.hparams:
self.global_config = self.hparams
self.hparams = self.hparams["vision_config"]
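+ # the layer-count key name varies across vision configs (some use "depth")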
+ self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+ self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
# load preprocessor config
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
- self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+ self.gguf_writer.add_vision_block_count(self.block_count)
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
# preprocessor config
"LlamaForCausalLM",
"MistralForCausalLM",
"MixtralForCausalLM",
- "Idefics3ForConditionalGeneration",
- "SmolVLMForConditionalGeneration",
+ "VLlama3ForCausalLM",
"LlavaForConditionalGeneration")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- # fix for SmolVLM2, missing `num_attention_heads` in config.json
- if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
- self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
- # fix for Pixtral, missing `num_attention_heads` in config.json
- if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
- and self.hparams.get("model_type") == "mistral":
- self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
def set_vocab(self):
try:
self._set_vocab_sentencepiece()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams["model_type"] == "pixtral":
- # fix missing config.json values
- self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
- self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
- self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
- self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+ # layer_norm_eps is not in config.json; it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = 12 # see tokenizer_config.json
else:
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
- # default values below are taken from HF tranformers code
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
self.gguf_writer.add_vision_use_silu(True)
class SmolVLMModel(VisionModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- # fix for SmolVLM2, missing some keys in config.json
- # default values are taken from transformers code
if self.hparams["model_type"] == "smolvlm_vision":
+ # fix for SmolVLM2, missing some keys in config.json
+ # default values are taken from transformers code
self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
- self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
def set_gguf_parameters(self):
super().set_gguf_parameters()
@ModelBase.register("NomicBertModel")
class NomicBertModel(BertModel):
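+ # TextModel.__init_subclass__ checks cls.__dict__ directly, so model_arch must be
+ # re-declared here even though BertModel already defines it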
+ model_arch = gguf.MODEL_ARCH.BERT
+
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
return n
+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: dict[str, Any] | None = None) -> str:
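+ # returns the HF architecture name used to look up the ModelBase subclass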
+ hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+ text_config = hparams.get("text_config", {})
+ vision_config = hparams.get("vision_config", {})
+ arch = hparams["architectures"][0]
+ # if "architectures" is found in the sub-config, use that instead
+ if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+ arch = text_config["architectures"][0]
+ elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+ arch = vision_config["architectures"][0]
+ return arch
+
+
def main() -> None:
args = parse_args()
logger.info(f"Loading model: {dir_model.name}")
- hparams = ModelBase.load_hparams(dir_model)
-
if args.mmproj:
if "mmproj" not in fname_out.name:
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
with torch.inference_mode():
output_type = ftype_map[args.outtype]
- model_architecture = hparams["architectures"][0]
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+ model_architecture = get_model_architecture(dir_model, model_type)
+ logger.info(f"Model architecture: {model_architecture}")
try:
model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
except NotImplementedError: