]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf-py : fix some metadata name extraction edge cases (#8591)
authorcompilade <redacted>
Sun, 21 Jul 2024 01:58:49 +0000 (21:58 -0400)
committerGitHub <redacted>
Sun, 21 Jul 2024 01:58:49 +0000 (21:58 -0400)
* gguf-py : fix some metadata name extraction edge cases

* convert_lora : use the lora dir for the model card path

* gguf-py : more metadata edge cases fixes

Multiple finetune versions are now joined together,
and the removal of the basename annotation on trailing versions
is more robust.

* gguf-py : add more name metadata extraction tests

* convert_lora : fix default filename

The default filename was previously hardcoded.

* convert_hf : Model.fname_out can no longer be None

* gguf-py : do not use title case for naming convention

Some models use acronyms in lowercase,
which can't be title-cased like other words,
so it's best to simply use the same case
as in the original model name.

Note that the size label still has an uppercased suffix
to make it distinguishable from the context size of a finetune.

convert_hf_to_gguf.py
convert_lora_to_gguf.py
gguf-py/gguf/metadata.py
gguf-py/gguf/utility.py
gguf-py/tests/test_metadata.py

index fba8dbbedebbd2aaedcaff8586c9cfeb01f377ed..139a92801fabe6bb441539d810145e25a9401b9a 100755 (executable)
@@ -48,7 +48,7 @@ class Model:
 
     dir_model: Path
     ftype: gguf.LlamaFileType
-    fname_out: Path | None
+    fname_out: Path
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
@@ -62,11 +62,12 @@ class Model:
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
     metadata_override: Path | None
+    dir_model_card: Path
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
@@ -90,6 +91,7 @@ class Model:
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
+        self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -345,7 +347,7 @@ class Model:
 
         total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
 
-        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)
 
         # Fallback to model directory name if metadata name is still missing
         if self.metadata.name is None:
@@ -359,27 +361,22 @@ class Model:
         output_type: str = self.ftype.name.partition("_")[2]
 
         # Filename Output
-        # Note: `not is_dir()` is used because `.is_file()` will not detect
-        #       file template strings as it doesn't actually exist as a file
-        if self.fname_out is not None and not self.fname_out.is_dir():
-            # Output path is a custom defined templated filename
-
-            # Process templated file name with the output ftype, useful with the "auto" ftype
-            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
-        else:
+        if self.fname_out.is_dir():
             # Generate default filename based on model specification and available metadata
             if not vocab_only:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None)
             else:
                 fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab")
 
-            # Check if preferred output directory path was provided
-            if self.fname_out is not None and self.fname_out.is_dir():
-                # output path is a directory
-                self.fname_out = self.fname_out / f"{fname_default}.gguf"
-            else:
-                # output in the same directory as the model by default
-                self.fname_out = self.dir_model / f"{fname_default}.gguf"
+            # Use the default filename
+            self.fname_out = self.fname_out / f"{fname_default}.gguf"
+        else:
+            # Output path is a custom defined templated filename
+            # Note: `not is_dir()` is used because `.is_file()` will not detect
+            #       file template strings as it doesn't actually exist as a file
+
+            # Process templated file name with the output ftype, useful with the "auto" ftype
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
 
         self.set_type()
 
@@ -3634,10 +3631,10 @@ def main() -> None:
         logger.error("Error: Cannot use temp file when splitting")
         sys.exit(1)
 
-    fname_out = None
-
     if args.outfile is not None:
         fname_out = args.outfile
+    else:
+        fname_out = dir_model
 
     logger.info(f"Loading model: {dir_model.name}")
 
@@ -3668,7 +3665,6 @@ def main() -> None:
         else:
             logger.info("Exporting model...")
             model_instance.write()
-            assert model_instance.fname_out is not None
             out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
             logger.info(f"Model successfully exported to {out_path}")
 
index 66e8da37cba7cf086f63065be70ee005e100babc..a88d0d4a978a9d8514a4a786abdce4ad0d27c299 100755 (executable)
@@ -290,7 +290,7 @@ if __name__ == '__main__':
         fname_out = args.outfile
     else:
         # output in the same directory as the model by default
-        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'
+        fname_out = dir_lora
 
     if os.path.exists(input_model):
         # lazy import load_file only if lora is in safetensors format.
@@ -304,12 +304,6 @@ if __name__ == '__main__':
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
     hparams = Model.load_hparams(dir_base_model)
-
-    with open(lora_config, "r") as f:
-        lparams: dict[str, Any] = json.load(f)
-
-    alpha: float = lparams["lora_alpha"]
-
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
@@ -320,12 +314,21 @@ if __name__ == '__main__':
         class LoraModel(model_class):
             model_arch = model_class.model_arch
 
+            lora_alpha: float
+
+            def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
+
+                super().__init__(*args, **kwargs)
+
+                self.dir_model_card = dir_lora_model
+                self.lora_alpha = float(lora_alpha)
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
             def set_gguf_parameters(self):
-                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
                 super().set_gguf_parameters()
 
             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
@@ -368,6 +371,11 @@ if __name__ == '__main__':
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)
 
+        with open(lora_config, "r") as f:
+            lparams: dict[str, Any] = json.load(f)
+
+        alpha: float = lparams["lora_alpha"]
+
         model_instance = LoraModel(
             dir_base_model,
             ftype,
@@ -376,6 +384,8 @@ if __name__ == '__main__':
             use_temp_file=False,
             eager=args.no_lazy,
             dry_run=args.dry_run,
+            dir_lora_model=dir_lora,
+            lora_alpha=alpha,
         )
 
         logger.info("Exporting model...")
index bac6ebfb3777a0edb6559afc7922fed8a7d9881e..15189f7177500868bff49c5f4a1d2f21a100f800 100644 (file)
@@ -54,6 +54,7 @@ class Metadata:
 
         model_card = Metadata.load_model_card(model_path)
         hf_params = Metadata.load_hf_parameters(model_path)
+        # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
 
         # heuristics
         metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
@@ -177,6 +178,12 @@ class Metadata:
             org_component = None
 
         name_parts: list[str] = model_full_name_component.split('-')
+
+        # Remove empty parts
+        for i in reversed(range(len(name_parts))):
+            if len(name_parts[i]) == 0:
+                del name_parts[i]
+
         name_types: list[
             set[Literal["basename", "size_label", "finetune", "version", "type"]]
         ] = [set() for _ in name_parts]
@@ -223,9 +230,19 @@ class Metadata:
                 name_parts[i] = part
             # Some easy to recognize finetune names
             elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
-                name_types[i].add("finetune")
-                if part.lower() == "lora":
-                    name_parts[i] = "LoRA"
+                if total_params < 0 and part.lower() == "lora":
+                    # ignore redundant "lora" in the finetune part when the output is a lora adapter
+                    name_types[i].add("type")
+                else:
+                    name_types[i].add("finetune")
+
+        # Ignore word-based size labels when there is at least a number-based one present
+        # TODO: should word-based size labels always be removed instead?
+        if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n):
+            for n, t in zip(name_parts, name_types):
+                if "size_label" in t:
+                    if all(c.isalpha() for c in n):
+                        t.remove("size_label")
 
         at_start = True
         # Find the basename through the annotated name
@@ -240,18 +257,18 @@ class Metadata:
 
         # Remove the basename annotation from trailing version
         for part, t in zip(reversed(name_parts), reversed(name_types)):
-            if "basename" in t:
-                if len(t) > 1:
-                    t.remove("basename")
+            if "basename" in t and len(t) > 1:
+                t.remove("basename")
             else:
                 break
 
         basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
-        size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
+        # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys)
+        size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None
         finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
         # TODO: should the basename version always be excluded?
-        # TODO: should multiple versions be joined together?
-        version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
+        # NOTE: multiple finetune versions are joined together
+        version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None
 
         if size_label is None and finetune is None and version is None:
             # Too ambiguous, output nothing
index ef76831b521eef0dd9baf28e2958052f0eadf0c9..40d59b75ee04ec6b46d219ea3be0b3a8fb8b3f35 100644 (file)
@@ -50,15 +50,15 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
     # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
 
     if base_name is not None:
-        name = base_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = base_name.strip().replace(' ', '-').replace('/', '-')
     elif model_name is not None:
-        name = model_name.strip().title().replace(' ', '-').replace('/', '-')
+        name = model_name.strip().replace(' ', '-').replace('/', '-')
     else:
         name = "ggml-model"
 
     parameters = f"-{size_label}" if size_label is not None else ""
 
-    finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
+    finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else ""
 
     version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
 
index 3fac8218883f1f80a1a55767661779f92707a51b..81a2a30ae60f401129e68f107695c66c2fbce214 100755 (executable)
@@ -54,7 +54,7 @@ class TestMetadataMethod(unittest.TestCase):
         self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
                          ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
 
-        # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
+        # Non standard naming
         self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
                          ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
 
@@ -71,7 +71,7 @@ class TestMetadataMethod(unittest.TestCase):
         self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
                          ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
 
-        # None standard and not easy to disambiguate
+        # Non standard and not easy to disambiguate
         self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
                          ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
 
@@ -123,6 +123,51 @@ class TestMetadataMethod(unittest.TestCase):
         self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
                          ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
 
+        # Ignore full-text size labels when there are number-based ones, and deduplicate size labels
+        self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
+                         ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))
+
+        # Instruct in a name without a size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"),
+                         ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None))
+
+        # Non-obvious splitting relying on 'chat' keyword
+        self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"),
+                         ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None))
+
+        # Multiple versions
+        self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
+                         ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))
+
+        # TODO: DPO in the name
+        self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"),
+                         ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B'))
+
+        # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"),
+                         ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B'))
+
+        # Too ambiguous
+        # TODO: should "base" be a 'finetune' or 'size_label'?
+        # (in this case it should be a size label, but other models use it to signal that they are not finetuned)
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"),
+                         ('Florence-2-base', 'microsoft', None, None, None, None))
+
+        ## Invalid cases ##
+
+        # Start with a dash and has dashes in rows
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
+                         ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
+
+        ## LoRA ##
+
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B'))
+
+        # Negative size --> output is a LoRA adaper --> prune "LoRA" out of the name to avoid redundancy with the suffix
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B'))
+
     def test_apply_metadata_heuristic_from_model_card(self):
         model_card = {
             'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
@@ -134,7 +179,7 @@ class TestMetadataMethod(unittest.TestCase):
         }
         got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
         expect = gguf.Metadata()
-        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
+        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
         expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
         expect.languages=['en']
         expect.datasets=['teknium/OpenHermes-2.5']