    return parser.parse_args()
-def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+def load_hparams_from_hf(hf_model_id: str) -> tuple[dict[str, Any], Path | None]:
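+    """Load the base model config for an adapter from the Hugging Face Hub.
+
+    Returns the config as a dict, plus the local snapshot directory when
+    config.json was resolved from the local HF cache, otherwise None.
+    """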
+    from huggingface_hub import try_to_load_from_cache
+
    # normally, an adapter does not ship with the base model config, so we
    # load it via AutoConfig
    config = AutoConfig.from_pretrained(hf_model_id)
-    return config.to_dict()
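+    # try_to_load_from_cache returns the cached file path as a str when the
+    # file is already in the local HF cache, a special sentinel object when a
+    # missing file has been recorded, or None when nothing is cached; only a
+    # real str path gives us a usable local snapshot directory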
+    config_path = try_to_load_from_cache(hf_model_id, "config.json")
+    cache_dir = Path(config_path).parent if isinstance(config_path, str) else None
+
+    return config.to_dict(), cache_dir
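+
+# Sketch of the new contract (repo id and revision below are hypothetical):
+#   hparams, cache_dir = load_hparams_from_hf("org/base-model")
+#   hparams   -> the usual AutoConfig dict (vocab_size, hidden_size, ...)
+#   cache_dir -> ~/.cache/huggingface/hub/models--org--base-model/snapshots/<rev>
+#                or None if config.json was not resolved from the local cache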
if __name__ == '__main__':
    # load base model
    if base_model_id is not None:
        logger.info(f"Loading base model from Hugging Face: {base_model_id}")
-        hparams = load_hparams_from_hf(base_model_id)
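+        # the returned cache snapshot directory doubles as the local base
+        # model directory for the rest of the conversion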
+        hparams, dir_base_model = load_hparams_from_hf(base_model_id)
    elif dir_base_model is None:
        if "base_model_name_or_path" in lparams:
            model_id = lparams["base_model_name_or_path"]
            logger.info(f"Loading base model from Hugging Face: {model_id}")
            try:
-                hparams = load_hparams_from_hf(model_id)
+                hparams, dir_base_model = load_hparams_from_hf(model_id)
            except OSError as e:
                logger.error(f"Failed to load base model config: {e}")
logger.error("Please try downloading the base model and add its path to --base")
        dir_lora_model=dir_lora,
        lora_alpha=alpha,
        hparams=hparams,
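+        # None unless the user explicitly passed a Hugging Face repo id for
+        # the base model on the command line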
+        remote_hf_model_id=base_model_id,
    )
    logger.info("Exporting model...")