class TextModel(ModelBase):
+ model_type = ModelType.TEXT
+ hf_arch: str
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ self.hf_arch = get_model_architecture(self.hparams, self.model_type)
if "text_config" in self.hparams:
# move the text_config to the root level
if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
+ def _try_set_pooling_type(self) -> None:
+ # Probe for a sentence-transformers Pooling module next to the model
+ # (dir_model/modules.json -> <path>/config.json) and, when one is found,
+ # record the corresponding gguf.PoolingType (MEAN / CLS / LAST) in the
+ # GGUF metadata. Silently does nothing when modules.json is absent or
+ # lists no Pooling module; raises NotImplementedError for any other
+ # pooling mode combination.
+ # get pooling path
+ pooling_path = None
+ module_path = self.dir_model / "modules.json"
+ if module_path.is_file():
+ with open(module_path, encoding="utf-8") as f:
+ modules = json.load(f)
+ for mod in modules:
+ if mod["type"] == "sentence_transformers.models.Pooling":
+ pooling_path = mod["path"]
+ break
+
+ # get pooling type
+ if pooling_path is not None:
+ with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+ pooling = json.load(f)
+ # the config flags are checked in priority order; exactly one is
+ # expected to be truthy in a well-formed sentence-transformers config
+ if pooling["pooling_mode_mean_tokens"]:
+ pooling_type = gguf.PoolingType.MEAN
+ elif pooling["pooling_mode_cls_token"]:
+ pooling_type = gguf.PoolingType.CLS
+ elif pooling["pooling_mode_lasttoken"]:
+ pooling_type = gguf.PoolingType.LAST
+ else:
+ raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
+ self.gguf_writer.add_pooling_type(pooling_type)
+
class VisionModel(ModelBase):
+ model_type = ModelType.VISION
model_arch = gguf.MODEL_ARCH.CLIP_VISION
- n_text_embd = 0
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
self.gguf_writer.add_file_type(self.ftype)
-@ModelBase.register("Qwen2ForCausalLM")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
class Qwen2Model(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2
def set_gguf_parameters(self):
super().set_gguf_parameters()
+ # embedding-style checkpoints (e.g. sentence-transformers exports) may
+ # ship a Pooling module config; record its pooling type in the GGUF
+ # metadata when present
+ self._try_set_pooling_type()
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "yarn":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # bare "Qwen2Model" checkpoints name tensors without the leading
+ # "model." prefix used by Qwen2ForCausalLM; re-add it so the existing
+ # tensor name mapping applies unchanged
+ if self.hf_arch == "Qwen2Model":
+ name = f"model.{name}" # map to Qwen2ForCausalLM tensors
+ yield from super().modify_tensors(data_torch, name, bid)
+
@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
class Qwen2VLModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_causal_attention(False)
-
- # get pooling path
- pooling_path = None
- module_path = self.dir_model / "modules.json"
- if module_path.is_file():
- with open(module_path, encoding="utf-8") as f:
- modules = json.load(f)
- for mod in modules:
- if mod["type"] == "sentence_transformers.models.Pooling":
- pooling_path = mod["path"]
- break
-
- # get pooling type
- if pooling_path is not None:
- with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
- pooling = json.load(f)
- if pooling["pooling_mode_mean_tokens"]:
- pooling_type = gguf.PoolingType.MEAN
- elif pooling["pooling_mode_cls_token"]:
- pooling_type = gguf.PoolingType.CLS
- else:
- raise NotImplementedError("Only MEAN and CLS pooling types supported")
- self.gguf_writer.add_pooling_type(pooling_type)
+ # pooling detection is deduplicated into the base-class helper, which
+ # also adds LAST-token pooling support on top of MEAN and CLS
+ self._try_set_pooling_type()
def set_vocab(self):
tokens, toktypes, tokpre = self.get_vocab_base()
return n
-def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
- hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
with torch.inference_mode():
output_type = ftype_map[args.outtype]
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
- model_architecture = get_model_architecture(dir_model, model_type)
+ hparams = ModelBase.load_hparams(dir_model)
+ model_architecture = get_model_architecture(hparams, model_type)
logger.info(f"Model architecture: {model_architecture}")
try:
model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)