]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama-model : support Qwen2 embedding models and pooling_mode_lasttoken (#13245)
authorJared Van Bortel <redacted>
Fri, 2 May 2025 15:42:30 +0000 (11:42 -0400)
committerGitHub <redacted>
Fri, 2 May 2025 15:42:30 +0000 (11:42 -0400)
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-model.cpp

index 7a7c1858ec2f04af083bcb31ccdeb12887ac449a..0862099da43eec9cb412011a489b4dd3246676cd 100755 (executable)
@@ -455,8 +455,12 @@ class ModelBase:
 
 
 class TextModel(ModelBase):
+    model_type = ModelType.TEXT
+    hf_arch: str
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.hf_arch = get_model_architecture(self.hparams, self.model_type)
 
         if "text_config" in self.hparams:
             # move the text_config to the root level
@@ -1075,10 +1079,36 @@ class TextModel(ModelBase):
         if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
             self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
 
+    def _try_set_pooling_type(self) -> None:
+        # get pooling path
+        pooling_path = None
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
+
+        # get pooling type
+        if pooling_path is not None:
+            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
+                pooling = json.load(f)
+            if pooling["pooling_mode_mean_tokens"]:
+                pooling_type = gguf.PoolingType.MEAN
+            elif pooling["pooling_mode_cls_token"]:
+                pooling_type = gguf.PoolingType.CLS
+            elif pooling["pooling_mode_lasttoken"]:
+                pooling_type = gguf.PoolingType.LAST
+            else:
+                raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
+            self.gguf_writer.add_pooling_type(pooling_type)
+
 
 class VisionModel(ModelBase):
+    model_type = ModelType.VISION
     model_arch = gguf.MODEL_ARCH.CLIP_VISION
-    n_text_embd = 0
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
@@ -2542,7 +2572,7 @@ class QwenModel(TextModel):
         self.gguf_writer.add_file_type(self.ftype)
 
 
-@ModelBase.register("Qwen2ForCausalLM")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
 class Qwen2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
@@ -2554,12 +2584,18 @@ class Qwen2Model(TextModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        self._try_set_pooling_type()
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
             if self.hparams["rope_scaling"].get("type") == "yarn":
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
                 self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
                 self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.hf_arch == "Qwen2Model":
+            name = f"model.{name}"  # map to Qwen2ForCausalLM tensors
+        yield from super().modify_tensors(data_torch, name, bid)
+
 
 @ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
 class Qwen2VLModel(TextModel):
@@ -3396,29 +3432,7 @@ class BertModel(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_causal_attention(False)
-
-        # get pooling path
-        pooling_path = None
-        module_path = self.dir_model / "modules.json"
-        if module_path.is_file():
-            with open(module_path, encoding="utf-8") as f:
-                modules = json.load(f)
-            for mod in modules:
-                if mod["type"] == "sentence_transformers.models.Pooling":
-                    pooling_path = mod["path"]
-                    break
-
-        # get pooling type
-        if pooling_path is not None:
-            with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
-                pooling = json.load(f)
-            if pooling["pooling_mode_mean_tokens"]:
-                pooling_type = gguf.PoolingType.MEAN
-            elif pooling["pooling_mode_cls_token"]:
-                pooling_type = gguf.PoolingType.CLS
-            else:
-                raise NotImplementedError("Only MEAN and CLS pooling types supported")
-            self.gguf_writer.add_pooling_type(pooling_type)
+        self._try_set_pooling_type()
 
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
@@ -5962,8 +5976,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
-def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
-    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
     arch = hparams["architectures"][0]
@@ -6034,7 +6047,8 @@ def main() -> None:
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
-        model_architecture = get_model_architecture(dir_model, model_type)
+        hparams = ModelBase.load_hparams(dir_model)
+        model_architecture = get_model_architecture(hparams, model_type)
         logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
index 74e46c3ee0f9523f00609efb8ba91b546896677a..7dd7bb6d1b5d9392d4155dc7da0a8039bc193a01 100644 (file)
@@ -2033,6 +2033,8 @@ class PoolingType(IntEnum):
     NONE = 0
     MEAN = 1
     CLS  = 2
+    LAST = 3
+    RANK = 4
 
 
 class GGMLQuantizationType(IntEnum):
index e163de76a759696584537f70e815c81c68a37d9b..08d21301374a39111e2810b6717b6332f5006b92 100644 (file)
@@ -773,6 +773,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             // fall through
         case LLM_ARCH_QWEN2:
             {
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;