]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert-lora : make `--base` optional (#10110)
authorXuan Son Nguyen <redacted>
Sat, 2 Nov 2024 11:53:17 +0000 (12:53 +0100)
committerGitHub <redacted>
Sat, 2 Nov 2024 11:53:17 +0000 (12:53 +0100)
* convert-lora : make `--base` optional

* lint

* handle case where base_model_name_or_path is invalid

* do not include metadata from base model

* clarify unspecified --base

* add small comment [no ci]

* trigger ci

convert_hf_to_gguf.py
convert_lora_to_gguf.py

index a34dabe235a34caba90eb195990bac06d892e861..76ee6cef52ac055dfcafa3eab374d8871ef242c1 100755 (executable)
@@ -72,7 +72,8 @@ class Model:
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
-                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
+                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
@@ -87,7 +88,7 @@ class Model:
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
             self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
-        self.hparams = Model.load_hparams(self.dir_model)
+        self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
@@ -1541,6 +1542,17 @@ class LlamaModel(Model):
             special_vocab._set_special_token("eot",    32010)
             special_vocab.add_to_gguf(self.gguf_writer)
 
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+        # Apply to granite small models only
+        if self.hparams.get("vocab_size", 32000) == 49152:
+            self.gguf_writer.add_add_bos_token(False)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
@@ -1557,17 +1569,6 @@ class LlamaModel(Model):
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                 self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
 
-        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
-        if tokenizer_config_file.is_file():
-            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
-                tokenizer_config_json = json.load(f)
-                if "add_prefix_space" in tokenizer_config_json:
-                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
-
-        # Apply to granite small models only
-        if self.hparams.get("vocab_size", 32000) == 49152:
-            self.gguf_writer.add_add_bos_token(False)
-
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
         if n_head_kv is not None and n_head != n_head_kv:
index 915e218366929fcb33406e498fd094b0294cd8d4..ed1014cae075aa1960e8a5f53ef38b4b76b6f2a9 100755 (executable)
@@ -12,6 +12,7 @@ import json
 from math import prod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from transformers import AutoConfig
 
 import torch
 
@@ -256,8 +257,8 @@ def parse_args() -> argparse.Namespace:
         help="only print out what will be done, without writing any new files",
     )
     parser.add_argument(
-        "--base", type=Path, required=True,
-        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
+        "--base", type=Path,
+        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
     )
     parser.add_argument(
         "lora_path", type=Path,
@@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
+def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+    # normally, adapter does not come with base model config, we need to load it from AutoConfig
+    config = AutoConfig.from_pretrained(hf_model_id)
+    return config.to_dict()
+
+
 if __name__ == '__main__':
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@@ -281,7 +288,7 @@ if __name__ == '__main__':
 
     ftype = ftype_map[args.outtype]
 
-    dir_base_model: Path = args.base
+    dir_base_model: Path | None = args.base
     dir_lora: Path = args.lora_path
     lora_config = dir_lora / "adapter_config.json"
     input_model = dir_lora / "adapter_model.safetensors"
@@ -301,9 +308,29 @@ if __name__ == '__main__':
         input_model = os.path.join(dir_lora, "adapter_model.bin")
         lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
 
+    # load LoRA config
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
     # load base model
-    logger.info(f"Loading base model: {dir_base_model.name}")
-    hparams = Model.load_hparams(dir_base_model)
+    if dir_base_model is None:
+        if "base_model_name_or_path" in lparams:
+            model_id = lparams["base_model_name_or_path"]
+            logger.info(f"Loading base model from Hugging Face: {model_id}")
+            try:
+                hparams = load_hparams_from_hf(model_id)
+            except OSError as e:
+                logger.error(f"Failed to load base model config: {e}")
+                logger.error("Please try downloading the base model and add its path to --base")
+                sys.exit(1)
+        else:
+            logger.error("'base_model_name_or_path' is not found in adapter_config.json")
+            logger.error("Base model config is required. Please download the base model and add its path to --base")
+            sys.exit(1)
+    else:
+        logger.info(f"Loading base model: {dir_base_model.name}")
+        hparams = Model.load_hparams(dir_base_model)
+
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
@@ -323,13 +350,15 @@ if __name__ == '__main__':
                 self.dir_model_card = dir_lora_model
                 self.lora_alpha = float(lora_alpha)
 
+            def set_vocab(self):
+                pass
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
             def set_gguf_parameters(self):
                 self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
-                super().set_gguf_parameters()
 
             def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
@@ -350,7 +379,7 @@ if __name__ == '__main__':
                         logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                         if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                             logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
-                            logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
+                            logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
                         sys.exit(1)
 
                     if base_name in tensor_map:
@@ -384,9 +413,6 @@ if __name__ == '__main__':
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)
 
-        with open(lora_config, "r") as f:
-            lparams: dict[str, Any] = json.load(f)
-
         alpha: float = lparams["lora_alpha"]
 
         model_instance = LoraModel(
@@ -399,6 +425,7 @@ if __name__ == '__main__':
             dry_run=args.dry_run,
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
+            hparams=hparams,
         )
 
         logger.info("Exporting model...")