metadata: Detailed Dataset Authorship Metadata (#8875)
author Brian <redacted>
Wed, 13 Nov 2024 10:10:38 +0000 (21:10 +1100)
committer GitHub <redacted>
Wed, 13 Nov 2024 10:10:38 +0000 (21:10 +1100)
The converter script can now read these two fields, which describe detailed base model and dataset sources.
This was done so that it is easier for Hugging Face to integrate detailed metadata as needed.

 -  base_model_sources (List[dict], optional)
 -  dataset_sources (List[dict], optional)

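For example, a model card carrying the new fields can be sketched as the plain dict the converter sees (values borrowed from the test fixtures below; real model cards carry these in YAML front matter):

    model_card = {
        "base_model_sources": [
            {
                "name": "OpenHermes 2.5",
                "organization": "Teknium",
                "version": "2.5",
                "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5",
            }
        ],
        "dataset_sources": [
            {
                "name": "OpenHermes 2.5",
                "author": "Teknium",
                "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5",
            }
        ],
    }
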
Datasets are now represented as:

 - general.dataset.count
 - general.dataset.{id}.name
 - general.dataset.{id}.author
 - general.dataset.{id}.version
 - general.dataset.{id}.organization
 - general.dataset.{id}.description
 - general.dataset.{id}.url
 - general.dataset.{id}.doi
 - general.dataset.{id}.uuid
 - general.dataset.{id}.repo_url

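For a single dataset source, the resulting KV pairs would look like this (values illustrative, borrowed from the test fixtures below):

    general.dataset.count = 1
    general.dataset.0.name = "OpenHermes 2.5"
    general.dataset.0.organization = "Teknium"
    general.dataset.0.version = "2.5"
    general.dataset.0.repo_url = "https://huggingface.co/teknium/OpenHermes-2.5"
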
This also adds the following metadata to base models:

 - general.base_model.{id}.description

examples/convert_legacy_llama.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/metadata.py
gguf-py/tests/test_metadata.py

index 9ab9ab06edf8fa7eb32b72701e57ad8804bfe3c3..c4ec5c524e9b13063bcfc74b2bf1a7f8d17c8a36 100755 (executable)
@@ -840,6 +840,8 @@ class OutputFile:
                         self.gguf.add_base_model_version(key, base_model_entry["version"])
                     if "organization" in base_model_entry:
                         self.gguf.add_base_model_organization(key, base_model_entry["organization"])
+                    if "description" in base_model_entry:
+                        self.gguf.add_base_model_description(key, base_model_entry["description"])
                     if "url" in base_model_entry:
                         self.gguf.add_base_model_url(key, base_model_entry["url"])
                     if "doi" in base_model_entry:
@@ -849,12 +851,32 @@ class OutputFile:
                     if "repo_url" in base_model_entry:
                         self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
 
+            if metadata.datasets is not None:
+                self.gguf.add_dataset_count(len(metadata.datasets))
+                for key, dataset_entry in enumerate(metadata.datasets):
+                    if "name" in dataset_entry:
+                        self.gguf.add_dataset_name(key, dataset_entry["name"])
+                    if "author" in dataset_entry:
+                        self.gguf.add_dataset_author(key, dataset_entry["author"])
+                    if "version" in dataset_entry:
+                        self.gguf.add_dataset_version(key, dataset_entry["version"])
+                    if "organization" in dataset_entry:
+                        self.gguf.add_dataset_organization(key, dataset_entry["organization"])
+                    if "description" in dataset_entry:
+                        self.gguf.add_dataset_description(key, dataset_entry["description"])
+                    if "url" in dataset_entry:
+                        self.gguf.add_dataset_url(key, dataset_entry["url"])
+                    if "doi" in dataset_entry:
+                        self.gguf.add_dataset_doi(key, dataset_entry["doi"])
+                    if "uuid" in dataset_entry:
+                        self.gguf.add_dataset_uuid(key, dataset_entry["uuid"])
+                    if "repo_url" in dataset_entry:
+                        self.gguf.add_dataset_repo_url(key, dataset_entry["repo_url"])
+
             if metadata.tags is not None:
                 self.gguf.add_tags(metadata.tags)
             if metadata.languages is not None:
                 self.gguf.add_languages(metadata.languages)
-            if metadata.datasets is not None:
-                self.gguf.add_datasets(metadata.datasets)
 
     def add_meta_arch(self, params: Params) -> None:
         # Metadata About The Neural Architecture Itself
index 7ab08b036e527781edc5c67fb503c75b7edd50b6..bc2b649d1a83dc6763fc58ce9512fb8059e51629 100644 (file)
@@ -64,15 +64,27 @@ class Keys:
         BASE_MODEL_AUTHOR          = "general.base_model.{id}.author"
         BASE_MODEL_VERSION         = "general.base_model.{id}.version"
         BASE_MODEL_ORGANIZATION    = "general.base_model.{id}.organization"
+        BASE_MODEL_DESCRIPTION     = "general.base_model.{id}.description"
         BASE_MODEL_URL             = "general.base_model.{id}.url" # Model Website/Paper
         BASE_MODEL_DOI             = "general.base_model.{id}.doi"
         BASE_MODEL_UUID            = "general.base_model.{id}.uuid"
         BASE_MODEL_REPO_URL        = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
 
+        # Dataset Source
+        DATASET_COUNT           = "general.dataset.count"
+        DATASET_NAME            = "general.dataset.{id}.name"
+        DATASET_AUTHOR          = "general.dataset.{id}.author"
+        DATASET_VERSION         = "general.dataset.{id}.version"
+        DATASET_ORGANIZATION    = "general.dataset.{id}.organization"
+        DATASET_DESCRIPTION     = "general.dataset.{id}.description"
+        DATASET_URL             = "general.dataset.{id}.url" # Dataset Website/Paper
+        DATASET_DOI             = "general.dataset.{id}.doi"
+        DATASET_UUID            = "general.dataset.{id}.uuid"
+        DATASET_REPO_URL        = "general.dataset.{id}.repo_url" # Dataset Source Repository (git/svn/etc...)
+
         # Array based KV stores
         TAGS                       = "general.tags"
         LANGUAGES                  = "general.languages"
-        DATASETS                   = "general.datasets"
 
     class LLM:
         VOCAB_SIZE                        = "{arch}.vocab_size"
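
The per-id keys above are plain formatted strings, so consumers can read them back by substituting an index. A minimal read-back sketch, assuming gguf-py's GGUFReader field layout (read_str is a hypothetical helper):

    from typing import Optional

    from gguf import GGUFReader

    reader = GGUFReader("model.gguf")

    def read_str(key: str) -> Optional[str]:
        # GGUF strings are stored as uint8 arrays; field.data indexes the value part
        field = reader.get_field(key)
        return bytes(field.parts[field.data[0]]).decode("utf-8") if field else None

    count_field = reader.get_field("general.dataset.count")
    if count_field is not None:
        count = int(count_field.parts[count_field.data[0]][0])  # uint32 scalar
        for i in range(count):
            print(read_str(f"general.dataset.{i}.name"),
                  read_str(f"general.dataset.{i}.repo_url"))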
index 0d8d8a0b087e94d763b8aa71c2fae9bf1b5df4a0..7a55d129653624db4893b4ce7945f43dc48afd21 100644 (file)
@@ -568,6 +568,9 @@ class GGUFWriter:
     def add_base_model_organization(self, source_id: int, organization: str) -> None:
         self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
 
+    def add_base_model_description(self, source_id: int, description: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description)
+
     def add_base_model_url(self, source_id: int, url: str) -> None:
         self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
 
@@ -580,15 +583,42 @@ class GGUFWriter:
     def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
         self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
 
+    def add_dataset_count(self, source_count: int) -> None:
+        self.add_uint32(Keys.General.DATASET_COUNT, source_count)
+
+    def add_dataset_name(self, source_id: int, name: str) -> None:
+        self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)
+
+    def add_dataset_author(self, source_id: int, author: str) -> None:
+        self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)
+
+    def add_dataset_version(self, source_id: int, version: str) -> None:
+        self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)
+
+    def add_dataset_organization(self, source_id: int, organization: str) -> None:
+        self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)
+
+    def add_dataset_description(self, source_id: int, description: str) -> None:
+        self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)
+
+    def add_dataset_url(self, source_id: int, url: str) -> None:
+        self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)
+
+    def add_dataset_doi(self, source_id: int, doi: str) -> None:
+        self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)
+
+    def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
+        self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)
+
+    def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
+        self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)
+
     def add_tags(self, tags: Sequence[str]) -> None:
         self.add_array(Keys.General.TAGS, tags)
 
     def add_languages(self, languages: Sequence[str]) -> None:
         self.add_array(Keys.General.LANGUAGES, languages)
 
-    def add_datasets(self, datasets: Sequence[str]) -> None:
-        self.add_array(Keys.General.DATASETS, datasets)
-
     def add_tensor_data_layout(self, layout: str) -> None:
         self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
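
The new writer methods are used like any other general.* setter. A minimal standalone sketch (normally the converter drives these calls; the flush sequence assumed here is the usual GGUFWriter lifecycle):

    from gguf import GGUFWriter

    writer = GGUFWriter("out.gguf", arch="llama")
    writer.add_dataset_count(1)
    writer.add_dataset_name(0, "OpenHermes 2.5")
    writer.add_dataset_organization(0, "Teknium")
    writer.add_dataset_repo_url(0, "https://huggingface.co/teknium/OpenHermes-2.5")
    # ... add tensors as usual, then flush:
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()
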
index db318542a279b606e95ff51c82fd77615fce30b8..321cbcd4c5507773d3be1defc1d3ac15e1c2ca05 100644 (file)
@@ -41,7 +41,7 @@ class Metadata:
     base_models: Optional[list[dict]] = None
     tags: Optional[list[str]] = None
     languages: Optional[list[str]] = None
-    datasets: Optional[list[str]] = None
+    datasets: Optional[list[dict]] = None
 
     @staticmethod
     def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
@@ -91,9 +91,11 @@ class Metadata:
         # Base Models is received here as an array of models
         metadata.base_models     = metadata_override.get("general.base_models",        metadata.base_models)
 
+        # Datasets are received here as an array of datasets
+        metadata.datasets        = metadata_override.get("general.datasets",           metadata.datasets)
+
         metadata.tags            = metadata_override.get(Keys.General.TAGS,            metadata.tags)
         metadata.languages       = metadata_override.get(Keys.General.LANGUAGES,       metadata.languages)
-        metadata.datasets        = metadata_override.get(Keys.General.DATASETS,        metadata.datasets)
 
         # Direct Metadata Override (via direct cli argument)
         if model_name is not None:
@@ -346,12 +348,12 @@ class Metadata:
             use_model_card_metadata("author", "model_creator")
             use_model_card_metadata("basename", "model_type")
 
-            if "base_model" in model_card:
+            if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card:
                 # This represents the parent models that this is based on
                 # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
                 # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
                 metadata_base_models = []
-                base_model_value = model_card.get("base_model", None)
+                base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None)))
 
                 if base_model_value is not None:
                     if isinstance(base_model_value, str):
@@ -364,18 +366,106 @@ class Metadata:
 
                 for model_id in metadata_base_models:
                     # NOTE: model size of base model is assumed to be similar to the size of the current model
-                    model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
                     base_model = {}
-                    if model_full_name_component is not None:
-                        base_model["name"] = Metadata.id_to_title(model_full_name_component)
-                    if org_component is not None:
-                        base_model["organization"] = Metadata.id_to_title(org_component)
-                    if version is not None:
-                        base_model["version"] = version
-                    if org_component is not None and model_full_name_component is not None:
-                        base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
+                    if isinstance(model_id, str):
+                        if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
+                            base_model["repo_url"] = model_id
+
+                            # Check if Hugging Face ID is present in URL
+                            if "huggingface.co" in model_id:
+                                match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
+                                if match:
+                                    model_id_component = match.group(1)
+                                    model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
+
+                                    # Populate model dictionary with extracted components
+                                    if model_full_name_component is not None:
+                                        base_model["name"] = Metadata.id_to_title(model_full_name_component)
+                                    if org_component is not None:
+                                        base_model["organization"] = Metadata.id_to_title(org_component)
+                                    if version is not None:
+                                        base_model["version"] = version
+
+                        else:
+                            # Likely a Hugging Face ID
+                            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+
+                            # Populate model dictionary with extracted components
+                            if model_full_name_component is not None:
+                                base_model["name"] = Metadata.id_to_title(model_full_name_component)
+                            if org_component is not None:
+                                base_model["organization"] = Metadata.id_to_title(org_component)
+                            if version is not None:
+                                base_model["version"] = version
+                            if org_component is not None and model_full_name_component is not None:
+                                base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
+
+                    elif isinstance(model_id, dict):
+                        base_model = model_id
+
+                    else:
+                        logger.error(f"base model entry '{str(model_id)}' not in a known format")
+
                     metadata.base_models.append(base_model)
 
+            if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card:
+                # This represents the datasets that this was trained from
+                metadata_datasets = []
+                dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None)))
+
+                if dataset_value is not None:
+                    if isinstance(dataset_value, str):
+                        metadata_datasets.append(dataset_value)
+                    elif isinstance(dataset_value, list):
+                        metadata_datasets.extend(dataset_value)
+
+                if metadata.datasets is None:
+                    metadata.datasets = []
+
+                for dataset_id in metadata_datasets:
+                    # NOTE: a dataset entry may be a plain name, a URL, or an already-structured dict
+                    dataset = {}
+                    if isinstance(dataset_id, str):
+                        if dataset_id.startswith(("http://", "https://", "ssh://")):
+                            dataset["repo_url"] = dataset_id
+
+                            # Check if Hugging Face ID is present in URL
+                            if "huggingface.co" in dataset_id:
+                                match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
+                                if match:
+                                    dataset_id_component = match.group(1)
+                                    dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
+
+                                    # Populate dataset dictionary with extracted components
+                                    if dataset_name_component is not None:
+                                        dataset["name"] = Metadata.id_to_title(dataset_name_component)
+                                    if org_component is not None:
+                                        dataset["organization"] = Metadata.id_to_title(org_component)
+                                    if version is not None:
+                                        dataset["version"] = version
+
+                        else:
+                            # Likely a Hugging Face ID
+                            dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
+
+                            # Populate dataset dictionary with extracted components
+                            if dataset_name_component is not None:
+                                dataset["name"] = Metadata.id_to_title(dataset_name_component)
+                            if org_component is not None:
+                                dataset["organization"] = Metadata.id_to_title(org_component)
+                            if version is not None:
+                                dataset["version"] = version
+                            if org_component is not None and dataset_name_component is not None:
+                                dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
+
+                    elif isinstance(dataset_id, dict):
+                        dataset = dataset_id
+
+                    else:
+                        logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
+
+                    metadata.datasets.append(dataset)
+
             use_model_card_metadata("license", "license")
             use_model_card_metadata("license_name", "license_name")
             use_model_card_metadata("license_link", "license_link")
@@ -386,9 +476,6 @@ class Metadata:
             use_array_model_card_metadata("languages", "languages")
             use_array_model_card_metadata("languages", "language")
 
-            use_array_model_card_metadata("datasets", "datasets")
-            use_array_model_card_metadata("datasets", "dataset")
-
         # Hugging Face Parameter Heuristics
         ####################################
 
@@ -493,6 +580,8 @@ class Metadata:
                     gguf_writer.add_base_model_version(key, base_model_entry["version"])
                 if "organization" in base_model_entry:
                     gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
+                if "description" in base_model_entry:
+                    gguf_writer.add_base_model_description(key, base_model_entry["description"])
                 if "url" in base_model_entry:
                     gguf_writer.add_base_model_url(key, base_model_entry["url"])
                 if "doi" in base_model_entry:
@@ -502,9 +591,29 @@ class Metadata:
                 if "repo_url" in base_model_entry:
                     gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
 
+        if self.datasets is not None:
+            gguf_writer.add_dataset_count(len(self.datasets))
+            for key, dataset_entry in enumerate(self.datasets):
+                if "name" in dataset_entry:
+                    gguf_writer.add_dataset_name(key, dataset_entry["name"])
+                if "author" in dataset_entry:
+                    gguf_writer.add_dataset_author(key, dataset_entry["author"])
+                if "version" in dataset_entry:
+                    gguf_writer.add_dataset_version(key, dataset_entry["version"])
+                if "organization" in dataset_entry:
+                    gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
+                if "description" in dataset_entry:
+                    gguf_writer.add_dataset_description(key, dataset_entry["description"])
+                if "url" in dataset_entry:
+                    gguf_writer.add_dataset_url(key, dataset_entry["url"])
+                if "doi" in dataset_entry:
+                    gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
+                if "uuid" in dataset_entry:
+                    gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
+                if "repo_url" in dataset_entry:
+                    gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
+
         if self.tags is not None:
             gguf_writer.add_tags(self.tags)
         if self.languages is not None:
             gguf_writer.add_languages(self.languages)
-        if self.datasets is not None:
-            gguf_writer.add_datasets(self.datasets)
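
The end-to-end effect of the heuristic, mirroring the new tests below: a bare dataset id on the model card is expanded into a structured entry (sketch using gguf-py's public API):

    import gguf

    model_card = {"datasets": "teknium/OpenHermes-2.5"}
    meta = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
    assert meta.datasets == [{
        "name": "OpenHermes 2.5",
        "organization": "Teknium",
        "version": "2.5",
        "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5",
    }]
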
index 81a2a30ae60f401129e68f107695c66c2fbce214..40d484f4eaa9d0c1270a5b70dc2e5153c6ed8d9f 100755 (executable)
@@ -182,8 +182,43 @@ class TestMetadataMethod(unittest.TestCase):
         expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
         expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
         expect.languages=['en']
-        expect.datasets=['teknium/OpenHermes-2.5']
+        expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]
+        self.assertEqual(got, expect)
 
+        # Base Model spec is inferred from model id
+        model_card = {'base_models': 'teknium/OpenHermes-2.5'}
+        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Base Model spec is only url
+        model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
+        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Base Model spec is given directly
+        model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
+        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Dataset spec is inferred from model id
+        model_card = {'datasets': 'teknium/OpenHermes-2.5'}
+        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Dataset spec is only url
+        model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
+        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Dataset spec is given directly
+        model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
+        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
         self.assertEqual(got, expect)
 
     def test_apply_metadata_heuristic_from_hf_parameters(self):