]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert-*.py: GGUF Naming Convention Refactor and Metadata Override Refactor (#7499)
authorBrian <redacted>
Thu, 18 Jul 2024 10:40:15 +0000 (20:40 +1000)
committerGitHub <redacted>
Thu, 18 Jul 2024 10:40:15 +0000 (20:40 +1000)
Main thing is that the default output filename will take this form

{name}{parameters}{finetune}{version}{encoding}{kind}

In addition this add and remove some entries in the KV store and adds a metadata class with automatic heuristics capability to derive some values based on model card content

* No Change:
  - Internal GGUF Spec
    - `general.architecture`
    - `general.quantization_version`
    - `general.alignment`
    - `general.file_type`
  - General Model Details
    - `general.name`
    - `general.author`
    - `general.version`
    - `general.description`
  - Licensing details
    - `general.license`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.url`
  - Model Source during conversion
    - `general.source.url`

* Removed:
  - Model Source during conversion
    - `general.source.huggingface.repository`

* Added:
  - General Model Details
    - `general.organization`
    - `general.finetune`
    - `general.basename`
    - `general.quantized_by`
    - `general.size_label`
  - Licensing details
    - `general.license.name`
    - `general.license.link`
  - Typically represents the converted GGUF repo (Unless made from scratch)
    - `general.doi`
    - `general.uuid`
    - `general.repo_url`
  - Model Source during conversion
    - `general.source.doi`
    - `general.source.uuid`
    - `general.source.repo_url`
  - Base Model Source
    - `general.base_model.count`
    - `general.base_model.{id}.name`
    - `general.base_model.{id}.author`
    - `general.base_model.{id}.version`
    - `general.base_model.{id}.organization`
    - `general.base_model.{id}.url` (Model Website/Paper)
    - `general.base_model.{id}.doi`
    - `general.base_model.{id}.uuid`
    - `general.base_model.{id}.repo_url` (Model Source Repository (git/svn/etc...))
  - Array based KV stores
    - `general.tags`
    - `general.languages`
    - `general.datasets`

---------

Co-authored-by: compilade <redacted>
Co-authored-by: Xuan Son Nguyen <redacted>
13 files changed:
convert_hf_to_gguf.py
convert_lora_to_gguf.py
examples/convert_legacy_llama.py
gguf-py/README.md
gguf-py/gguf/__init__.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/metadata.py [new file with mode: 0644]
gguf-py/gguf/utility.py [new file with mode: 0644]
gguf-py/pyproject.toml
gguf-py/tests/__init__.py [new file with mode: 0644]
gguf-py/tests/test_gguf.py [deleted file]
gguf-py/tests/test_metadata.py [new file with mode: 0755]

index c2aba909706d0a8fa7dbeca5346a3a79b9e6270c..769d49a8b6f0a8d11a67a057a2cd08806ce89da4 100755 (executable)
@@ -48,34 +48,38 @@ class Model:
 
     dir_model: Path
     ftype: gguf.LlamaFileType
+    fname_out: Path | None
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
     lazy: bool
-    model_name: str | None
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
     block_count: int
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
-    fname_out: Path
     gguf_writer: gguf.GGUFWriter
+    model_name: str | None
+    metadata_override: Path | None
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
-                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path | None, is_big_endian: bool = False,
+                 use_temp_file: bool = False, eager: bool = False,
+                 metadata_override: Path | None = None, model_name: str | None = None,
+                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
+
         self.dir_model = dir_model
         self.ftype = ftype
+        self.fname_out = fname_out
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
         self.lazy = not eager
-        self.model_name = model_name
         self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
@@ -84,6 +88,10 @@ class Model:
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
+        self.metadata_override = metadata_override
+        self.model_name = model_name
+
+        # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
             # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
             _, first_tensor = next(self.get_tensors())
@@ -93,10 +101,8 @@ class Model:
             else:
                 logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                 self.ftype = gguf.LlamaFileType.MOSTLY_BF16
-        ftype_up: str = self.ftype.name.partition("_")[2].upper()
-        ftype_lw: str = ftype_up.lower()
-        # allow templating the file name with the output ftype, useful with the "auto" ftype
-        self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
+
+        # Configure GGUF Writer
         self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
 
@@ -193,7 +199,6 @@ class Model:
         return new_name
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_block_count(self.block_count)
 
         if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
@@ -250,7 +255,7 @@ class Model:
 
         return False
 
-    def write_tensors(self):
+    def prepare_tensors(self):
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
         for name, data_torch in self.get_tensors():
@@ -333,9 +338,67 @@ class Model:
 
                 self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
 
+    def set_type(self):
+        self.gguf_writer.add_type(gguf.GGUFType.MODEL)
+
+    def prepare_metadata(self, vocab_only: bool):
+
+        total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count()
+
+        self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model, self.model_name, total_params)
+
+        # Fallback to model directory name if metadata name is still missing
+        if self.metadata.name is None:
+            self.metadata.name = self.dir_model.name
+
+        # Generate parameter weight class (useful for leader boards) if not yet determined
+        if self.metadata.size_label is None and total_params > 0:
+            self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
+
+        # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
+        output_type: str = self.ftype.name.partition("_")[2]
+
+        # Filename Output
+        # Note: `not is_dir()` is used because `.is_file()` will not detect
+        #       file template strings as it doesn't actually exist as a file
+        if self.fname_out is not None and not self.fname_out.is_dir():
+            # Output path is a custom defined templated filename
+
+            # Process templated file name with the output ftype, useful with the "auto" ftype
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
+        else:
+            # Generate default filename based on model specification and available metadata
+            if not vocab_only:
+                fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None)
+            else:
+                fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab")
+
+            # Check if preferred output directory path was provided
+            if self.fname_out is not None and self.fname_out.is_dir():
+                # output path is a directory
+                self.fname_out = self.fname_out / f"{fname_default}.gguf"
+            else:
+                # output in the same directory as the model by default
+                self.fname_out = self.dir_model / f"{fname_default}.gguf"
+
+        self.set_type()
+
+        logger.info("Set meta model")
+        self.metadata.set_gguf_meta_model(self.gguf_writer)
+
+        logger.info("Set model parameters")
+        self.set_gguf_parameters()
+
+        logger.info("Set model tokenizer")
+        self.set_vocab()
+
+        logger.info("Set model quantization version")
+        self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+
     def write(self):
-        self.write_tensors()
-        self.gguf_writer.write_header_to_file(self.fname_out)
+        self.prepare_tensors()
+        self.prepare_metadata(vocab_only=False)
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
@@ -343,7 +406,9 @@ class Model:
     def write_vocab(self):
         if len(self.gguf_writer.tensors) != 1:
             raise ValueError('Splitting the vocabulary is not supported')
-        self.gguf_writer.write_header_to_file(self.fname_out)
+
+        self.prepare_metadata(vocab_only=True)
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
 
@@ -780,7 +845,6 @@ class GPTNeoXModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -836,7 +900,6 @@ class BloomModel(Model):
     model_arch = gguf.MODEL_ARCH.BLOOM
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("Bloom")
         n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
         n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
         self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
@@ -913,7 +976,6 @@ class MPTModel(Model):
 
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layers"]
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_block_count(block_count)
@@ -952,7 +1014,6 @@ class OrionModel(Model):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        hf_repo = self.hparams.get("_name_or_path", "")
 
         ctx_length = 0
         if "max_sequence_length" in self.hparams:
@@ -965,8 +1026,6 @@ class OrionModel(Model):
             raise ValueError("gguf: can not find ctx length parameter.")
 
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
-        self.gguf_writer.add_source_hf_repo(hf_repo)
         self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
         self.gguf_writer.add_context_length(ctx_length)
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -990,7 +1049,6 @@ class BaichuanModel(Model):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        hf_repo = self.hparams.get("_name_or_path", "")
 
         ctx_length = 0
         if "max_sequence_length" in self.hparams:
@@ -1002,8 +1060,6 @@ class BaichuanModel(Model):
         else:
             raise ValueError("gguf: can not find ctx length parameter.")
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
-        self.gguf_writer.add_source_hf_repo(hf_repo)
         self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
         self.gguf_writer.add_context_length(ctx_length)
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1117,7 +1173,6 @@ class XverseModel(Model):
         block_count = self.hparams["num_hidden_layers"]
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        hf_repo = self.hparams.get("_name_or_path", "")
 
         ctx_length = 0
         if "max_sequence_length" in self.hparams:
@@ -1129,8 +1184,6 @@ class XverseModel(Model):
         else:
             raise ValueError("gguf: can not find ctx length parameter.")
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
-        self.gguf_writer.add_source_hf_repo(hf_repo)
         self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
         self.gguf_writer.add_context_length(ctx_length)
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1189,7 +1242,6 @@ class FalconModel(Model):
         if n_head_kv is None:
             n_head_kv = self.hparams.get("n_head_kv", 1)  # old name
 
-        self.gguf_writer.add_name("Falcon")
         self.gguf_writer.add_context_length(2048)  # not in config.json
         self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1234,7 +1286,6 @@ class StarCoderModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layer"]
 
-        self.gguf_writer.add_name("StarCoder")
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
         self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@@ -1269,7 +1320,6 @@ class RefactModel(Model):
 
         block_count = self.hparams["n_layer"]
 
-        self.gguf_writer.add_name("Refact")
         # refact uses Alibi. So this is from config.json which might be used by training.
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@@ -1324,7 +1374,6 @@ class StableLMModel(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -1386,8 +1435,8 @@ class StableLMModel(Model):
 
         return [(new_name, data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._q_norms is not None or self._k_norms is not None:
             # flatten two `list[dict[str, Tensor]]` into a single `list[str]`
@@ -1503,8 +1552,8 @@ class LlamaModel(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -1567,7 +1616,6 @@ class GrokModel(Model):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_name("Grok")
 
     _experts: list[dict[str, Tensor]] | None = None
 
@@ -1616,7 +1664,6 @@ class DbrxModel(Model):
     def set_gguf_parameters(self):
         ffn_config = self.hparams["ffn_config"]
         attn_config = self.hparams["attn_config"]
-        self.gguf_writer.add_name(self.hparams["model_type"])
         self.gguf_writer.add_block_count(self.hparams["n_layers"])
 
         self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
@@ -1685,7 +1732,6 @@ class MiniCPMModel(Model):
 
     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
-        self.gguf_writer.add_name("MiniCPM")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -1755,7 +1801,6 @@ class QwenModel(Model):
         self._set_vocab_qwen()
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("Qwen")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -1831,8 +1876,8 @@ class Qwen2MoeModel(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -1846,7 +1891,6 @@ class GPT2Model(Model):
     model_arch = gguf.MODEL_ARCH.GPT2
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
         self.gguf_writer.add_context_length(self.hparams["n_ctx"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@@ -1889,7 +1933,6 @@ class Phi2Model(Model):
         n_embd = self.find_hparam(["hidden_size", "n_embd"])
         n_head = self.find_hparam(["num_attention_heads", "n_head"])
 
-        self.gguf_writer.add_name("Phi2")
         self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
 
         self.gguf_writer.add_embedding_length(n_embd)
@@ -2011,7 +2054,6 @@ class Phi3MiniModel(Model):
         orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
         rope_dims = n_embd // n_head
 
-        self.gguf_writer.add_name("Phi3")
         self.gguf_writer.add_context_length(max_pos_embds)
         self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
         self.gguf_writer.add_embedding_length(n_embd)
@@ -2068,7 +2110,6 @@ class PlamoModel(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name("PLaMo")
         self.gguf_writer.add_context_length(4096)  # not in config.json
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
@@ -2113,7 +2154,6 @@ class CodeShellModel(Model):
     def set_gguf_parameters(self):
         block_count = self.hparams["n_layer"]
 
-        self.gguf_writer.add_name("CodeShell")
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
         self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
@@ -2272,7 +2312,6 @@ class InternLM2Model(Model):
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
         self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
@@ -2440,7 +2479,6 @@ class GemmaModel(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -2481,7 +2519,6 @@ class Gemma2Model(Model):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
         self.gguf_writer.add_block_count(block_count)
@@ -2556,7 +2593,6 @@ class MambaModel(Model):
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
@@ -2735,7 +2771,6 @@ class OpenELMModel(Model):
         assert self.block_count == len(self._num_query_heads)
         assert self.block_count == len(self._ffn_dims)
 
-        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_context_length(self.hparams["max_context_length"])
         self.gguf_writer.add_embedding_length(n_embd)
@@ -2909,8 +2944,8 @@ class ArcticModel(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -2988,8 +3023,8 @@ class DeepseekV2Model(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
 
         if self._experts is not None:
             # flatten `list[dict[str, Tensor]]` into `list[str]`
@@ -3107,7 +3142,6 @@ class T5Model(Model):
         self.gguf_writer.add_add_eos_token(True)
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name("T5")
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
             n_ctx = 512
@@ -3181,7 +3215,6 @@ class JaisModel(Model):
         self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
         self.gguf_writer.add_context_length(self.hparams["n_positions"])
         self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
@@ -3227,8 +3260,8 @@ class JaisModel(Model):
 
         return tensors
 
-    def write_tensors(self):
-        super().write_tensors()
+    def prepare_tensors(self):
+        super().prepare_tensors()
         self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
 
 
@@ -3539,6 +3572,10 @@ def parse_args() -> argparse.Namespace:
         "--no-tensor-first-split", action="store_true",
         help="do not add tensors to the first split (disabled by default)"
     )
+    parser.add_argument(
+        "--metadata", type=Path,
+        help="Specify the path for an authorship metadata override file"
+    )
 
     return parser.parse_args()
 
@@ -3564,7 +3601,10 @@ def split_str_to_n_bytes(split_str: str) -> int:
 def main() -> None:
     args = parse_args()
 
-    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+    if args.verbose:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
 
     dir_model = args.model
 
@@ -3585,37 +3625,33 @@ def main() -> None:
         logger.error("Error: Cannot use temp file when splitting")
         sys.exit(1)
 
+    fname_out = None
+
     if args.outfile is not None:
         fname_out = args.outfile
-    else:
-        # output in the same directory as the model by default
-        fname_out = dir_model / 'ggml-model-{ftype}.gguf'
 
     logger.info(f"Loading model: {dir_model.name}")
 
     hparams = Model.load_hparams(dir_model)
 
     with torch.inference_mode():
+        output_type = ftype_map[args.outtype]
+        model_architecture = hparams["architectures"][0]
+
         try:
-            model_class = Model.from_model_architecture(hparams["architectures"][0])
+            model_class = Model.from_model_architecture(model_architecture)
         except NotImplementedError:
-            logger.error(f"Model {hparams['architectures'][0]} is not supported")
+            logger.error(f"Model {model_architecture} is not supported")
             sys.exit(1)
 
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
-                                     args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
+        model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
+                                     is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
+                                     eager=args.no_lazy,
+                                     metadata_override=args.metadata, model_name=args.model_name,
+                                     split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split)
 
-        logger.info("Set model parameters")
-        model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL)
-        model_instance.set_gguf_parameters()
-
-        logger.info("Set model tokenizer")
-        model_instance.set_vocab()
-
-        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-
         if args.vocab_only:
             logger.info("Exporting model vocab...")
             model_instance.write_vocab()
@@ -3623,6 +3659,7 @@ def main() -> None:
         else:
             logger.info("Exporting model...")
             model_instance.write()
+            assert model_instance.fname_out is not None
             out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out
             logger.info(f"Model successfully exported to {out_path}")
 
index 4bb939d45d6bd4016d93cafd66cfe132356502fe..66e8da37cba7cf086f63065be70ee005e100babc 100755 (executable)
@@ -251,6 +251,10 @@ def parse_args() -> argparse.Namespace:
         "--verbose", action="store_true",
         help="increase output verbosity",
     )
+    parser.add_argument(
+        "--dry-run", action="store_true",
+        help="only print out what will be done, without writing any new files",
+    )
     parser.add_argument(
         "--base", type=Path, required=True,
         help="directory containing base model file",
@@ -300,6 +304,12 @@ if __name__ == '__main__':
     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
     hparams = Model.load_hparams(dir_base_model)
+
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
+    alpha: float = lparams["lora_alpha"]
+
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
@@ -310,6 +320,14 @@ if __name__ == '__main__':
         class LoraModel(model_class):
             model_arch = model_class.model_arch
 
+            def set_type(self):
+                self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
+                self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
+
+            def set_gguf_parameters(self):
+                self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
+                super().set_gguf_parameters()
+
             def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_map: dict[str, PartialLoraTensor] = {}
 
@@ -357,18 +375,9 @@ if __name__ == '__main__':
             is_big_endian=args.bigendian,
             use_temp_file=False,
             eager=args.no_lazy,
-            model_name=None,
+            dry_run=args.dry_run,
         )
 
-        with open(lora_config, "r") as f:
-            lparams: dict[str, Any] = json.load(f)
-
-        alpha = lparams["lora_alpha"]
-
-        model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER)
-        model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
-        model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha))
-        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
         logger.info("Exporting model...")
         model_instance.write()
         logger.info(f"Model successfully exported to {model_instance.fname_out}")
index c2c73e8ad39ec78c6f976d6e51d343ff9bc07687..9ab9ab06edf8fa7eb32b72701e57ad8804bfe3c3 100755 (executable)
@@ -24,7 +24,7 @@ from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar, Optional
+from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar
 
 import numpy as np
 
@@ -346,42 +346,6 @@ class Params:
         return params
 
 
-@dataclass
-class Metadata:
-    name: Optional[str] = None
-    author: Optional[str] = None
-    version: Optional[str] = None
-    url: Optional[str] = None
-    description: Optional[str] = None
-    license: Optional[str] = None
-    source_url: Optional[str] = None
-    source_hf_repo: Optional[str] = None
-
-    @staticmethod
-    def load(metadata_path: Path) -> Metadata:
-        if metadata_path is None or not metadata_path.exists():
-            return Metadata()
-
-        with open(metadata_path, 'r') as file:
-            data = json.load(file)
-
-        # Create a new Metadata instance
-        metadata = Metadata()
-
-        # Assigning values to Metadata attributes if they exist in the JSON file
-        # This is based on LLM_KV_NAMES mapping in llama.cpp
-        metadata.name = data.get("general.name")
-        metadata.author = data.get("general.author")
-        metadata.version = data.get("general.version")
-        metadata.url = data.get("general.url")
-        metadata.description = data.get("general.description")
-        metadata.license = data.get("general.license")
-        metadata.source_url = data.get("general.source.url")
-        metadata.source_hf_repo = data.get("general.source.huggingface.repository")
-
-        return metadata
-
-
 #
 # data loading
 # TODO: reuse (probably move to gguf.py?)
@@ -806,7 +770,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
-    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
+    def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:
@@ -824,16 +788,73 @@ class OutputFile:
                 self.gguf.add_author(metadata.author)
             if metadata.version is not None:
                 self.gguf.add_version(metadata.version)
-            if metadata.url is not None:
-                self.gguf.add_url(metadata.url)
+            if metadata.organization is not None:
+                self.gguf.add_organization(metadata.organization)
+
+            if metadata.finetune is not None:
+                self.gguf.add_finetune(metadata.finetune)
+            if metadata.basename is not None:
+                self.gguf.add_basename(metadata.basename)
+
             if metadata.description is not None:
                 self.gguf.add_description(metadata.description)
+            if metadata.quantized_by is not None:
+                self.gguf.add_quantized_by(metadata.quantized_by)
+
+            if metadata.size_label is not None:
+                self.gguf.add_size_label(metadata.size_label)
+
             if metadata.license is not None:
-                self.gguf.add_licence(metadata.license)
+                self.gguf.add_license(metadata.license)
+            if metadata.license_name is not None:
+                self.gguf.add_license_name(metadata.license_name)
+            if metadata.license_link is not None:
+                self.gguf.add_license_link(metadata.license_link)
+
+            if metadata.url is not None:
+                self.gguf.add_url(metadata.url)
+            if metadata.doi is not None:
+                self.gguf.add_doi(metadata.doi)
+            if metadata.uuid is not None:
+                self.gguf.add_uuid(metadata.uuid)
+            if metadata.repo_url is not None:
+                self.gguf.add_repo_url(metadata.repo_url)
+
             if metadata.source_url is not None:
                 self.gguf.add_source_url(metadata.source_url)
-            if metadata.source_hf_repo is not None:
-                self.gguf.add_source_hf_repo(metadata.source_hf_repo)
+            if metadata.source_doi is not None:
+                self.gguf.add_source_doi(metadata.source_doi)
+            if metadata.source_uuid is not None:
+                self.gguf.add_source_uuid(metadata.source_uuid)
+            if metadata.source_repo_url is not None:
+                self.gguf.add_source_repo_url(metadata.source_repo_url)
+
+            if metadata.base_models is not None:
+                self.gguf.add_base_model_count(len(metadata.base_models))
+                for key, base_model_entry in enumerate(metadata.base_models):
+                    if "name" in base_model_entry:
+                        self.gguf.add_base_model_name(key, base_model_entry["name"])
+                    if "author" in base_model_entry:
+                        self.gguf.add_base_model_author(key, base_model_entry["author"])
+                    if "version" in base_model_entry:
+                        self.gguf.add_base_model_version(key, base_model_entry["version"])
+                    if "organization" in base_model_entry:
+                        self.gguf.add_base_model_organization(key, base_model_entry["organization"])
+                    if "url" in base_model_entry:
+                        self.gguf.add_base_model_url(key, base_model_entry["url"])
+                    if "doi" in base_model_entry:
+                        self.gguf.add_base_model_doi(key, base_model_entry["doi"])
+                    if "uuid" in base_model_entry:
+                        self.gguf.add_base_model_uuid(key, base_model_entry["uuid"])
+                    if "repo_url" in base_model_entry:
+                        self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
+
+            if metadata.tags is not None:
+                self.gguf.add_tags(metadata.tags)
+            if metadata.languages is not None:
+                self.gguf.add_languages(metadata.languages)
+            if metadata.datasets is not None:
+                self.gguf.add_datasets(metadata.datasets)
 
     def add_meta_arch(self, params: Params) -> None:
         # Metadata About The Neural Architecture Itself
@@ -944,7 +965,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -978,7 +999,7 @@ class OutputFile:
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
-        metadata: Metadata | None = None,
+        metadata: gguf.Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
@@ -1021,35 +1042,32 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileT
     raise ValueError(f"Unexpected combination of types: {name_to_type}")
 
 
-def model_parameter_count(model: LazyModel) -> int:
-    total_model_parameters = 0
-    for i, (name, lazy_tensor) in enumerate(model.items()):
-        sum_weights_in_tensor = 1
+def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]:
+    total_params = 0
+    shared_params = 0
+    expert_params = 0
+
+    for name, lazy_tensor in tensors:
+        # We don't need these
+        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
+            continue
+
+        # Got A Tensor
+        sum_weights_in_tensor: int = 1
+
+        # Tensor Volume
         for dim in lazy_tensor.shape:
             sum_weights_in_tensor *= dim
-        total_model_parameters += sum_weights_in_tensor
-    return total_model_parameters
-
-
-def model_parameter_count_rounded_notation(model_params_count: int) -> str:
-    if model_params_count > 1e12 :
-        # Trillions Of Parameters
-        scaled_model_params = model_params_count * 1e-12
-        scale_suffix = "T"
-    elif model_params_count > 1e9 :
-        # Billions Of Parameters
-        scaled_model_params = model_params_count * 1e-9
-        scale_suffix = "B"
-    elif model_params_count > 1e6 :
-        # Millions Of Parameters
-        scaled_model_params = model_params_count * 1e-6
-        scale_suffix = "M"
-    else:
-        # Thousands Of Parameters
-        scaled_model_params = model_params_count * 1e-3
-        scale_suffix = "K"
 
-    return f"{round(scaled_model_params)}{scale_suffix}"
+        if ".experts." in name:
+            if ".experts.0." in name:
+                expert_params += sum_weights_in_tensor
+        else:
+            shared_params += sum_weights_in_tensor
+
+        total_params += sum_weights_in_tensor
+
+    return total_params, shared_params, expert_params
 
 
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -1231,34 +1249,24 @@ class VocabFactory:
         return vocab, special_vocab
 
 
-def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> str:
-    quantization = {
+def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str:
+    name = metadata.name if metadata.name is not None else None
+    basename = metadata.basename if metadata.basename is not None else None
+    finetune = metadata.finetune if metadata.finetune is not None else None
+    version = metadata.version if metadata.version is not None else None
+    size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0)
+
+    output_type = {
         GGMLFileType.AllF32:    "F32",
         GGMLFileType.MostlyF16: "F16",
         GGMLFileType.MostlyQ8_0: "Q8_0",
     }[file_type]
 
-    parameters = model_parameter_count_rounded_notation(model_params_count)
-
-    expert_count = ""
-    if params.n_experts is not None:
-        expert_count = f"{params.n_experts}x"
-
-    version = ""
-    if metadata is not None and metadata.version is not None:
-        version = f"-{metadata.version}"
+    return gguf.naming_convention(name, basename, finetune, version, size_label, output_type)
 
-    name = "ggml-model"
-    if metadata is not None and metadata.name is not None:
-        name = metadata.name
-    elif params.path_model is not None:
-        name = params.path_model.name
 
-    return f"{name}{version}-{expert_count}{parameters}-{quantization}"
-
-
-def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
-    default_filename = default_convention_outfile(file_type, params, model_params_count, metadata)
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path:
+    default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata)
     ret = model_paths[0].parent / f"{default_filename}.gguf"
     if ret in model_paths:
         logger.error(
@@ -1297,8 +1305,9 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--pad-vocab",    action="store_true",    help="add pad tokens when model vocab expects more than tokenizer metadata provides")
     parser.add_argument("--skip-unknown", action="store_true",    help="skip unknown tensor names instead of failing")
     parser.add_argument("--verbose",      action="store_true",    help="increase output verbosity")
-    parser.add_argument("--metadata",     type=Path,              help="Specify the path for a metadata file")
+    parser.add_argument("--metadata",     type=Path,              help="Specify the path for an authorship metadata override file")
     parser.add_argument("--get-outfile",  action="store_true",    help="get calculated default outfile name")
+    parser.add_argument("--model-name",   type=str, default=None, help="name of the model")
 
     args = parser.parse_args(args_in)
 
@@ -1310,32 +1319,36 @@ def main(args_in: list[str] | None = None) -> None:
     else:
         logging.basicConfig(level=logging.INFO)
 
-    metadata = Metadata.load(args.metadata)
+    model_name = args.model_name
+    dir_model = args.model
+
+    metadata = gguf.Metadata.load(args.metadata, dir_model, model_name)
 
     if args.get_outfile:
-        model_plus = load_some_model(args.model)
+        model_plus = load_some_model(dir_model)
         params = Params.load(model_plus)
-        model   = convert_model_names(model_plus.model, params, args.skip_unknown)
-        model_params_count = model_parameter_count(model_plus.model)
-        ftype   = pick_output_type(model, args.outtype)
-        print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}") # noqa: NP100
+        model = convert_model_names(model_plus.model, params, args.skip_unknown)
+        model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+        ftype = pick_output_type(model, args.outtype)
+
+        if (metadata is None or metadata.name is None) and params.path_model is not None:
+            metadata.name = params.path_model.name
+
+        print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100
         return
 
     if args.no_vocab and args.vocab_only:
         raise ValueError("--vocab-only does not make sense with --no-vocab")
 
     if args.dump_single:
-        model_plus = lazy_load_file(args.model)
+        model_plus = lazy_load_file(dir_model)
         do_dump_model(model_plus)
         return
 
     if not args.vocab_only:
-        model_plus = load_some_model(args.model)
+        model_plus = load_some_model(dir_model)
     else:
-        model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
-
-    model_params_count = model_parameter_count(model_plus.model)
-    logger.info(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")
+        model_plus = ModelPlus(model = {}, paths = [dir_model / 'dummy'], format = 'none', vocab = None)
 
     if args.dump:
         do_dump_model(model_plus)
@@ -1368,7 +1381,7 @@ def main(args_in: list[str] | None = None) -> None:
         logger.info(f"params = {params}")
 
     model_parent_path = model_plus.paths[0].parent
-    vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
+    vocab_path = Path(args.vocab_dir or dir_model or model_parent_path)
     vocab_factory = VocabFactory(vocab_path)
     vocab_types = None if args.no_vocab else args.vocab_type.split(",")
     vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path)
@@ -1399,13 +1412,21 @@ def main(args_in: list[str] | None = None) -> None:
 
     assert params is not None
 
+    if metadata.name is None and params.path_model is not None:
+        metadata.name = params.path_model.name
+
+    model_params_count = per_model_weight_count_estimation(model_plus.model.items())
+    logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})")
+
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
     model   = model_plus.model
     model   = convert_model_names(model, params, args.skip_unknown)
     ftype   = pick_output_type(model, args.outtype)
     model   = convert_to_output_type(model, ftype)
-    outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)
+    outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata)
+
+    metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0)
 
     params.ftype = ftype
     logger.info(f"Writing {outfile}, format {ftype}")
index 9dd888f3180d1459d481661d90e40731b811c7d0..24af96a17a5bb2b3f0e5ebc2266b72b8a34c6eb4 100644 (file)
@@ -78,5 +78,13 @@ python -m build
 python -m twine upload dist/*
 ```
 
+## Run Unit Tests
+
+From root of this repository you can run this command to run all the unit tests
+
+```bash
+python -m unittest discover ./gguf-py -v
+```
+
 ## TODO
 - [ ] Include conversion scripts as command line entry points in this package.
index ea5146b161bc882e8fead9849e1889a2efda28e1..243defc4c1ca42d3713017d8902592f54ac849cd 100644 (file)
@@ -5,3 +5,5 @@ from .gguf_writer import *
 from .quants import *
 from .tensor_mapping import *
 from .vocab import *
+from .utility import *
+from .metadata import *
index 5eb3df706e6e292f6bb246c278067ef9a3aac6c2..e343c2ef1659af536d8f5012a325cc6d4cfbfff5 100644 (file)
@@ -19,19 +19,60 @@ GGML_QUANT_VERSION     = 2  # GGML_QNT_VERSION from ggml.h
 
 class Keys:
     class General:
-        TYPE                 = "general.type"
-        ARCHITECTURE         = "general.architecture"
-        QUANTIZATION_VERSION = "general.quantization_version"
-        ALIGNMENT            = "general.alignment"
-        NAME                 = "general.name"
-        AUTHOR               = "general.author"
-        VERSION              = "general.version"
-        URL                  = "general.url"
-        DESCRIPTION          = "general.description"
-        LICENSE              = "general.license"
-        SOURCE_URL           = "general.source.url"
-        SOURCE_HF_REPO       = "general.source.huggingface.repository"
-        FILE_TYPE            = "general.file_type"
+        TYPE                       = "general.type"
+        ARCHITECTURE               = "general.architecture"
+        QUANTIZATION_VERSION       = "general.quantization_version"
+        ALIGNMENT                  = "general.alignment"
+        FILE_TYPE                  = "general.file_type"
+
+        # Authorship Metadata
+        NAME                       = "general.name"
+        AUTHOR                     = "general.author"
+        VERSION                    = "general.version"
+        ORGANIZATION               = "general.organization"
+
+        FINETUNE                   = "general.finetune"
+        BASENAME                   = "general.basename"
+
+        DESCRIPTION                = "general.description"
+        QUANTIZED_BY               = "general.quantized_by"
+
+        SIZE_LABEL                 = "general.size_label"
+
+        # Licensing details
+        LICENSE                    = "general.license"
+        LICENSE_NAME               = "general.license.name"
+        LICENSE_LINK               = "general.license.link"
+
+        # Typically represents the converted GGUF repo (Unless native)
+        URL                        = "general.url" # Model Website/Paper
+        DOI                        = "general.doi"
+        UUID                       = "general.uuid"
+        REPO_URL                   = "general.repo_url" # Model Source Repository (git/svn/etc...)
+
+        # Model Source during conversion
+        SOURCE_URL                 = "general.source.url" # Model Website/Paper
+        SOURCE_DOI                 = "general.source.doi"
+        SOURCE_UUID                = "general.source.uuid"
+        SOURCE_REPO_URL            = "general.source.repo_url" # Model Source Repository (git/svn/etc...)
+
+        # Base Model Source. There can be more than one source if it's a merged
+        # model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
+        # tracing linage of models as it is finetuned or merged over time.
+        BASE_MODEL_COUNT           = "general.base_model.count"
+        BASE_MODEL_NAME            = "general.base_model.{id}.name"
+        BASE_MODEL_AUTHOR          = "general.base_model.{id}.author"
+        BASE_MODEL_VERSION         = "general.base_model.{id}.version"
+        BASE_MODEL_ORGANIZATION    = "general.base_model.{id}.organization"
+        BASE_MODEL_URL             = "general.base_model.{id}.url" # Model Website/Paper
+        BASE_MODEL_DOI             = "general.base_model.{id}.doi"
+        BASE_MODEL_UUID            = "general.base_model.{id}.uuid"
+        BASE_MODEL_REPO_URL        = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
+
+        # Array based KV stores
+        TAGS                       = "general.tags"
+        LANGUAGES                  = "general.languages"
+        DATASETS                   = "general.datasets"
 
     class LLM:
         VOCAB_SIZE                        = "{arch}.vocab_size"
@@ -1233,7 +1274,6 @@ KEY_GENERAL_URL                  = Keys.General.URL
 KEY_GENERAL_DESCRIPTION          = Keys.General.DESCRIPTION
 KEY_GENERAL_LICENSE              = Keys.General.LICENSE
 KEY_GENERAL_SOURCE_URL           = Keys.General.SOURCE_URL
-KEY_GENERAL_SOURCE_HF_REPO       = Keys.General.SOURCE_HF_REPO
 KEY_GENERAL_FILE_TYPE            = Keys.General.FILE_TYPE
 
 # LLM
index b0197961d46a830942023622eab7ea14868dcfe1..ba6f53cda25a18a55a293782774b2080da385ede 100644 (file)
@@ -7,6 +7,7 @@ import struct
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, auto
+from math import prod
 from pathlib import Path
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
@@ -106,6 +107,53 @@ class GGUFWriter:
 
         self.add_architecture()
 
+    def get_total_parameter_count(self) -> tuple[int, int, int, int]:
+        total_params = 0
+        shared_params = 0
+        expert_params = 0
+
+        expert_sum = 0
+        n_expert_tensors = 0
+
+        last_lora_a: tuple[str, TensorInfo] | None = None
+
+        for tensors in self.tensors:
+            for name, info in tensors.items():
+
+                shape = info.shape
+
+                if name.endswith(".lora_a"):
+                    last_lora_a = (name, info)
+                    continue
+                elif name.endswith(".lora_b"):
+                    if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
+                        # Bail when the LoRA pair can't be found trivially
+                        logger.warning("can't measure LoRA size correctly, tensor order is unusual")
+                        return 0, 0, 0, 0
+                    else:
+                        shape = (*shape[:-1], last_lora_a[1].shape[-1])
+
+                size = prod(shape)
+
+                if "_exps." in name:
+                    expert_params += (size // shape[-3])
+                    expert_sum += shape[-3]
+                    n_expert_tensors += 1
+                else:
+                    shared_params += size
+
+                total_params += size
+
+        # Hopefully this should work even for variable-expert-count models
+        expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
+
+        # Negate the total to signal it's likely not exact
+        if last_lora_a is not None:
+            total_params = -total_params
+
+        # NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
+        return total_params, shared_params, expert_params, expert_count
+
     def format_shard_names(self, path: Path) -> list[Path]:
         if len(self.tensors) == 1:
             return [path]
@@ -115,6 +163,7 @@ class GGUFWriter:
         if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
             # allow calling this multiple times as long as the path is the same
             return
+
         if self.state is not WriterState.NO_FILE:
             raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
 
@@ -136,6 +185,8 @@ class GGUFWriter:
 
         if self.dry_run:
             logger.info("Dry run, not writing files")
+            for name in filenames:
+                print(name)  # noqa: NP100
             exit()
 
         return filenames
@@ -430,43 +481,114 @@ class GGUFWriter:
     def add_architecture(self) -> None:
         self.add_string(Keys.General.ARCHITECTURE, self.arch)
 
+    def add_quantization_version(self, quantization_version: int) -> None:
+        self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
+
+    def add_custom_alignment(self, alignment: int) -> None:
+        self.data_alignment = alignment
+        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+
+    def add_file_type(self, ftype: int) -> None:
+        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+
+    def add_name(self, name: str) -> None:
+        self.add_string(Keys.General.NAME, name)
+
     def add_author(self, author: str) -> None:
         self.add_string(Keys.General.AUTHOR, author)
 
     def add_version(self, version: str) -> None:
         self.add_string(Keys.General.VERSION, version)
 
-    def add_tensor_data_layout(self, layout: str) -> None:
-        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+    def add_organization(self, organization: str) -> None:
+        self.add_string(Keys.General.ORGANIZATION, organization)
 
-    def add_url(self, url: str) -> None:
-        self.add_string(Keys.General.URL, url)
+    def add_finetune(self, finetune: str) -> None:
+        self.add_string(Keys.General.FINETUNE, finetune)
+
+    def add_basename(self, basename: str) -> None:
+        self.add_string(Keys.General.BASENAME, basename)
 
     def add_description(self, description: str) -> None:
         self.add_string(Keys.General.DESCRIPTION, description)
 
-    def add_licence(self, licence: str) -> None:
-        self.add_string(Keys.General.LICENSE, licence)
+    def add_quantized_by(self, quantized: str) -> None:
+        self.add_string(Keys.General.QUANTIZED_BY, quantized)
+
+    def add_size_label(self, size_label: str) -> None:
+        self.add_string(Keys.General.SIZE_LABEL, size_label)
+
+    def add_license(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE, license)
+
+    def add_license_name(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE_NAME, license)
+
+    def add_license_link(self, license: str) -> None:
+        self.add_string(Keys.General.LICENSE_LINK, license)
+
+    def add_url(self, url: str) -> None:
+        self.add_string(Keys.General.URL, url)
+
+    def add_doi(self, doi: str) -> None:
+        self.add_string(Keys.General.DOI, doi)
+
+    def add_uuid(self, uuid: str) -> None:
+        self.add_string(Keys.General.UUID, uuid)
+
+    def add_repo_url(self, repo_url: str) -> None:
+        self.add_string(Keys.General.REPO_URL, repo_url)
 
     def add_source_url(self, url: str) -> None:
         self.add_string(Keys.General.SOURCE_URL, url)
 
-    def add_source_hf_repo(self, repo: str) -> None:
-        self.add_string(Keys.General.SOURCE_HF_REPO, repo)
+    def add_source_doi(self, doi: str) -> None:
+        self.add_string(Keys.General.SOURCE_DOI, doi)
 
-    def add_file_type(self, ftype: int) -> None:
-        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+    def add_source_uuid(self, uuid: str) -> None:
+        self.add_string(Keys.General.SOURCE_UUID, uuid)
 
-    def add_name(self, name: str) -> None:
-        self.add_string(Keys.General.NAME, name)
+    def add_source_repo_url(self, repo_url: str) -> None:
+        self.add_string(Keys.General.SOURCE_REPO_URL, repo_url)
 
-    def add_quantization_version(self, quantization_version: int) -> None:
-        self.add_uint32(
-            Keys.General.QUANTIZATION_VERSION, quantization_version)
+    def add_base_model_count(self, source_count: int) -> None:
+        self.add_uint32(Keys.General.BASE_MODEL_COUNT, source_count)
 
-    def add_custom_alignment(self, alignment: int) -> None:
-        self.data_alignment = alignment
-        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+    def add_base_model_name(self, source_id: int, name: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_NAME.format(id=source_id), name)
+
+    def add_base_model_author(self, source_id: int, author: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_AUTHOR.format(id=source_id), author)
+
+    def add_base_model_version(self, source_id: int, version: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version)
+
+    def add_base_model_organization(self, source_id: int, organization: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization)
+
+    def add_base_model_url(self, source_id: int, url: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
+
+    def add_base_model_doi(self, source_id: int, doi: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_DOI.format(id=source_id), doi)
+
+    def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_UUID.format(id=source_id), uuid)
+
+    def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
+        self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
+
+    def add_tags(self, tags: Sequence[str]) -> None:
+        self.add_array(Keys.General.TAGS, tags)
+
+    def add_languages(self, languages: Sequence[str]) -> None:
+        self.add_array(Keys.General.LANGUAGES, languages)
+
+    def add_datasets(self, datasets: Sequence[str]) -> None:
+        self.add_array(Keys.General.DATASETS, datasets)
+
+    def add_tensor_data_layout(self, layout: str) -> None:
+        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_vocab_size(self, size: int) -> None:
         self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size)
diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py
new file mode 100644 (file)
index 0000000..be297f2
--- /dev/null
@@ -0,0 +1,485 @@
+from __future__ import annotations
+
+import re
+import json
+import yaml
+import logging
+from pathlib import Path
+from typing import Any, Literal, Optional
+from dataclasses import dataclass
+
+from .constants import Keys
+
+import gguf
+
+logger = logging.getLogger("metadata")
+
+
+@dataclass
+class Metadata:
+    # Authorship Metadata to be written to GGUF KV Store
+    name: Optional[str] = None
+    author: Optional[str] = None
+    version: Optional[str] = None
+    organization: Optional[str] = None
+    finetune: Optional[str] = None
+    basename: Optional[str] = None
+    description: Optional[str] = None
+    quantized_by: Optional[str] = None
+    size_label: Optional[str] = None
+    url: Optional[str] = None
+    doi: Optional[str] = None
+    uuid: Optional[str] = None
+    repo_url: Optional[str] = None
+    source_url: Optional[str] = None
+    source_doi: Optional[str] = None
+    source_uuid: Optional[str] = None
+    source_repo_url: Optional[str] = None
+    license: Optional[str] = None
+    license_name: Optional[str] = None
+    license_link: Optional[str] = None
+    base_models: Optional[list[dict]] = None
+    tags: Optional[list[str]] = None
+    languages: Optional[list[str]] = None
+    datasets: Optional[list[str]] = None
+
+    @staticmethod
+    def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
+        # This grabs as many contextual authorship metadata as possible from the model repository
+        # making any conversion as required to match the gguf kv store metadata format
+        # as well as giving users the ability to override any authorship metadata that may be incorrect
+
+        # Create a new Metadata instance
+        metadata = Metadata()
+
+        model_card = Metadata.load_model_card(model_path)
+        hf_params = Metadata.load_hf_parameters(model_path)
+
+        # heuristics
+        metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
+
+        # Metadata Override File Provided
+        # This is based on LLM_KV_NAMES mapping in llama.cpp
+        metadata_override = Metadata.load_metadata_override(metadata_override_path)
+
+        metadata.author          = metadata_override.get(Keys.General.AUTHOR,          metadata.author)
+        metadata.version         = metadata_override.get(Keys.General.VERSION,         metadata.version)
+        metadata.organization    = metadata_override.get(Keys.General.ORGANIZATION,    metadata.organization)
+
+        metadata.finetune        = metadata_override.get(Keys.General.FINETUNE,        metadata.finetune)
+        metadata.basename        = metadata_override.get(Keys.General.BASENAME,        metadata.basename)
+
+        metadata.description     = metadata_override.get(Keys.General.DESCRIPTION,     metadata.description)
+        metadata.quantized_by    = metadata_override.get(Keys.General.QUANTIZED_BY,    metadata.quantized_by)
+
+        metadata.size_label      = metadata_override.get(Keys.General.SIZE_LABEL,      metadata.size_label)
+        metadata.license_name    = metadata_override.get(Keys.General.LICENSE_NAME,    metadata.license_name)
+        metadata.license_link    = metadata_override.get(Keys.General.LICENSE_LINK,    metadata.license_link)
+
+        metadata.url             = metadata_override.get(Keys.General.URL,             metadata.url)
+        metadata.doi             = metadata_override.get(Keys.General.DOI,             metadata.doi)
+        metadata.uuid            = metadata_override.get(Keys.General.UUID,            metadata.uuid)
+        metadata.repo_url        = metadata_override.get(Keys.General.REPO_URL,        metadata.repo_url)
+
+        metadata.source_url      = metadata_override.get(Keys.General.SOURCE_URL,      metadata.source_url)
+        metadata.source_doi      = metadata_override.get(Keys.General.SOURCE_DOI,      metadata.source_doi)
+        metadata.source_uuid     = metadata_override.get(Keys.General.SOURCE_UUID,     metadata.source_uuid)
+        metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)
+
+        # Base Models is received here as an array of models
+        metadata.base_models     = metadata_override.get("general.base_models",        metadata.base_models)
+
+        metadata.tags            = metadata_override.get(Keys.General.TAGS,            metadata.tags)
+        metadata.languages       = metadata_override.get(Keys.General.LANGUAGES,       metadata.languages)
+        metadata.datasets        = metadata_override.get(Keys.General.DATASETS,        metadata.datasets)
+
+        # Direct Metadata Override (via direct cli argument)
+        if model_name is not None:
+            metadata.name = model_name
+
+        return metadata
+
+    @staticmethod
+    def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]:
+        if metadata_override_path is None or not metadata_override_path.is_file():
+            return {}
+
+        with open(metadata_override_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]:
+        if model_path is None or not model_path.is_dir():
+            return {}
+
+        model_card_path = model_path / "README.md"
+
+        if not model_card_path.is_file():
+            return {}
+
+        # The model card metadata is assumed to always be in YAML
+        # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473
+        with open(model_card_path, "r", encoding="utf-8") as f:
+            if f.readline() == "---\n":
+                raw = f.read().partition("---\n")[0]
+                data = yaml.safe_load(raw)
+                if isinstance(data, dict):
+                    return data
+                else:
+                    logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict")
+                    return {}
+            else:
+                return {}
+
+    @staticmethod
+    def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]:
+        if model_path is None or not model_path.is_dir():
+            return {}
+
+        config_path = model_path / "config.json"
+
+        if not config_path.is_file():
+            return {}
+
+        with open(config_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def id_to_title(string):
+        # Convert capitalization into title form unless acronym or version number
+        return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()])
+
+    @staticmethod
+    def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]:
+        # Huggingface often store model id as '<org>/<model name>'
+        # so let's parse it and apply some heuristics if possible for model name components
+
+        if model_id is None:
+            # model ID missing
+            return None, None, None, None, None, None
+
+        if ' ' in model_id:
+            # model ID is actually a normal human sentence
+            # which means its most likely a normal model name only
+            # not part of the hugging face naming standard, but whatever
+            return model_id, None, None, None, None, None
+
+        if '/' in model_id:
+            # model ID (huggingface style)
+            org_component, model_full_name_component = model_id.split('/', 1)
+        else:
+            # model ID but missing org components
+            org_component, model_full_name_component = None, model_id
+
+        # Check if we erroneously matched against './' or '../' etc...
+        if org_component is not None and org_component[0] == '.':
+            org_component = None
+
+        name_parts: list[str] = model_full_name_component.split('-')
+        name_types: list[
+            set[Literal["basename", "size_label", "finetune", "version", "type"]]
+        ] = [set() for _ in name_parts]
+
+        # Annotate the name
+        for i, part in enumerate(name_parts):
+            # Version
+            if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE):
+                name_types[i].add("version")
+            # Quant type (should not be there for base models, but still annotated)
+            elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE):
+                name_types[i].add("type")
+                name_parts[i] = part.upper()
+            # Model size
+            elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE):
+                part = part.replace("_", ".")
+                # Handle weird bloom-7b1 notation
+                if part[-1].isdecimal():
+                    part = part[:-2] + "." + part[-1] + part[-2]
+                # Normalize the size suffixes
+                if len(part) > 1 and part[-2].isdecimal():
+                    if part[-1] in "kmbt":
+                        part = part[:-1] + part[-1].upper()
+                if total_params != 0:
+                    try:
+                        label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1]))
+                        # Only use it as a size label if it's close or bigger than the model size
+                        # Note that LoRA adapters don't necessarily include all layers,
+                        # so this is why bigger label sizes are accepted.
+                        # Do not use the size label when it's smaller than 1/8 of the model size
+                        if (total_params < 0 and label_params < abs(total_params) // 8) or (
+                            # Check both directions when the current model isn't a LoRA adapter
+                            total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8
+                        ):
+                            # Likely a context length
+                            name_types[i].add("finetune")
+                            # Lowercase the size when it's a context length
+                            part = part[:-1] + part[-1].lower()
+                    except ValueError:
+                        # Failed to convert the size label to float, use it anyway
+                        pass
+                if len(name_types[i]) == 0:
+                    name_types[i].add("size_label")
+                name_parts[i] = part
+            # Some easy to recognize finetune names
+            elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE):
+                name_types[i].add("finetune")
+                if part.lower() == "lora":
+                    name_parts[i] = "LoRA"
+
+        at_start = True
+        # Find the basename through the annotated name
+        for part, t in zip(name_parts, name_types):
+            if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t):
+                t.add("basename")
+            else:
+                if at_start:
+                    at_start = False
+                if len(t) == 0:
+                    t.add("finetune")
+
+        # Remove the basename annotation from trailing version
+        for part, t in zip(reversed(name_parts), reversed(name_types)):
+            if "basename" in t:
+                if len(t) > 1:
+                    t.remove("basename")
+            else:
+                break
+
+        basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None
+        size_label = "-".join(s for s, t in zip(name_parts, name_types) if "size_label" in t) or None
+        finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None
+        # TODO: should the basename version always be excluded?
+        # TODO: should multiple versions be joined together?
+        version = ([v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t] or [None])[-1]
+
+        if size_label is None and finetune is None and version is None:
+            # Too ambiguous, output nothing
+            basename = None
+
+        return model_full_name_component, org_component, basename, finetune, version, size_label
+
+    @staticmethod
+    def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata:
+        # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+
+        # Model Card Heuristics
+        ########################
+        if model_card is not None:
+
+            if "model_name" in model_card and metadata.name is None:
+                # Not part of huggingface model card standard but notice some model creator using it
+                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+                metadata.name = model_card.get("model_name")
+
+            if "model_creator" in model_card and metadata.author is None:
+                # Not part of huggingface model card standard but notice some model creator using it
+                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+                metadata.author = model_card.get("model_creator")
+
+            if "model_type" in model_card and metadata.basename is None:
+                # Not part of huggingface model card standard but notice some model creator using it
+                # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
+                metadata.basename = model_card.get("model_type")
+
+            if "base_model" in model_card:
+                # This represents the parent models that this is based on
+                # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
+                # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
+                metadata_base_models = []
+                base_model_value = model_card.get("base_model", None)
+
+                if base_model_value is not None:
+                    if isinstance(base_model_value, str):
+                        metadata_base_models.append(base_model_value)
+                    elif isinstance(base_model_value, list):
+                        metadata_base_models.extend(base_model_value)
+
+                if metadata.base_models is None:
+                    metadata.base_models = []
+
+                for model_id in metadata_base_models:
+                    # NOTE: model size of base model is assumed to be similar to the size of the current model
+                    model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+                    base_model = {}
+                    if model_full_name_component is not None:
+                        base_model["name"] = Metadata.id_to_title(model_full_name_component)
+                    if org_component is not None:
+                        base_model["organization"] = Metadata.id_to_title(org_component)
+                    if version is not None:
+                        base_model["version"] = version
+                    if org_component is not None and model_full_name_component is not None:
+                        base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
+                    metadata.base_models.append(base_model)
+
+            if "license" in model_card and metadata.license is None:
+                metadata.license = model_card.get("license")
+
+            if "license_name" in model_card and metadata.license_name is None:
+                metadata.license_name = model_card.get("license_name")
+
+            if "license_link" in model_card and metadata.license_link is None:
+                metadata.license_link = model_card.get("license_link")
+
+            tags_value = model_card.get("tags", None)
+            if tags_value is not None:
+
+                if metadata.tags is None:
+                    metadata.tags = []
+
+                if isinstance(tags_value, str):
+                    metadata.tags.append(tags_value)
+                elif isinstance(tags_value, list):
+                    metadata.tags.extend(tags_value)
+
+            pipeline_tags_value = model_card.get("pipeline_tag", None)
+            if pipeline_tags_value is not None:
+
+                if metadata.tags is None:
+                    metadata.tags = []
+
+                if isinstance(pipeline_tags_value, str):
+                    metadata.tags.append(pipeline_tags_value)
+                elif isinstance(pipeline_tags_value, list):
+                    metadata.tags.extend(pipeline_tags_value)
+
+            language_value = model_card.get("languages", model_card.get("language", None))
+            if language_value is not None:
+
+                if metadata.languages is None:
+                    metadata.languages = []
+
+                if isinstance(language_value, str):
+                    metadata.languages.append(language_value)
+                elif isinstance(language_value, list):
+                    metadata.languages.extend(language_value)
+
+            dataset_value = model_card.get("datasets", model_card.get("dataset", None))
+            if dataset_value is not None:
+
+                if metadata.datasets is None:
+                    metadata.datasets = []
+
+                if isinstance(dataset_value, str):
+                    metadata.datasets.append(dataset_value)
+                elif isinstance(dataset_value, list):
+                    metadata.datasets.extend(dataset_value)
+
+        # Hugging Face Parameter Heuristics
+        ####################################
+
+        if hf_params is not None:
+
+            hf_name_or_path = hf_params.get("_name_or_path")
+            if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1:
+                # Use _name_or_path only if its actually a model name and not some computer path
+                # e.g. 'meta-llama/Llama-2-7b-hf'
+                model_id = hf_name_or_path
+                model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+                if metadata.name is None and model_full_name_component is not None:
+                    metadata.name = Metadata.id_to_title(model_full_name_component)
+                if metadata.organization is None and org_component is not None:
+                    metadata.organization = Metadata.id_to_title(org_component)
+                if metadata.basename is None and basename is not None:
+                    metadata.basename = basename
+                if metadata.finetune is None and finetune is not None:
+                    metadata.finetune = finetune
+                if metadata.version is None and version is not None:
+                    metadata.version = version
+                if metadata.size_label is None and size_label is not None:
+                    metadata.size_label = size_label
+
+        # Directory Folder Name Fallback Heuristics
+        ############################################
+        if model_path is not None:
+            model_id = model_path.name
+            model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
+            if metadata.name is None and model_full_name_component is not None:
+                metadata.name = Metadata.id_to_title(model_full_name_component)
+            if metadata.organization is None and org_component is not None:
+                metadata.organization = Metadata.id_to_title(org_component)
+            if metadata.basename is None and basename is not None:
+                metadata.basename = basename
+            if metadata.finetune is None and finetune is not None:
+                metadata.finetune = finetune
+            if metadata.version is None and version is not None:
+                metadata.version = version
+            if metadata.size_label is None and size_label is not None:
+                metadata.size_label = size_label
+
+        return metadata
+
+    def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
+        assert self.name is not None
+        gguf_writer.add_name(self.name)
+
+        if self.author is not None:
+            gguf_writer.add_author(self.author)
+        if self.version is not None:
+            gguf_writer.add_version(self.version)
+        if self.organization is not None:
+            gguf_writer.add_organization(self.organization)
+
+        if self.finetune is not None:
+            gguf_writer.add_finetune(self.finetune)
+        if self.basename is not None:
+            gguf_writer.add_basename(self.basename)
+
+        if self.description is not None:
+            gguf_writer.add_description(self.description)
+        if self.quantized_by is not None:
+            gguf_writer.add_quantized_by(self.quantized_by)
+
+        if self.size_label is not None:
+            gguf_writer.add_size_label(self.size_label)
+
+        if self.license is not None:
+            gguf_writer.add_license(self.license)
+        if self.license_name is not None:
+            gguf_writer.add_license_name(self.license_name)
+        if self.license_link is not None:
+            gguf_writer.add_license_link(self.license_link)
+
+        if self.url is not None:
+            gguf_writer.add_url(self.url)
+        if self.doi is not None:
+            gguf_writer.add_doi(self.doi)
+        if self.uuid is not None:
+            gguf_writer.add_uuid(self.uuid)
+        if self.repo_url is not None:
+            gguf_writer.add_repo_url(self.repo_url)
+
+        if self.source_url is not None:
+            gguf_writer.add_source_url(self.source_url)
+        if self.source_doi is not None:
+            gguf_writer.add_source_doi(self.source_doi)
+        if self.source_uuid is not None:
+            gguf_writer.add_source_uuid(self.source_uuid)
+        if self.source_repo_url is not None:
+            gguf_writer.add_source_repo_url(self.source_repo_url)
+
+        if self.base_models is not None:
+            gguf_writer.add_base_model_count(len(self.base_models))
+            for key, base_model_entry in enumerate(self.base_models):
+                if "name" in base_model_entry:
+                    gguf_writer.add_base_model_name(key, base_model_entry["name"])
+                if "author" in base_model_entry:
+                    gguf_writer.add_base_model_author(key, base_model_entry["author"])
+                if "version" in base_model_entry:
+                    gguf_writer.add_base_model_version(key, base_model_entry["version"])
+                if "organization" in base_model_entry:
+                    gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
+                if "url" in base_model_entry:
+                    gguf_writer.add_base_model_url(key, base_model_entry["url"])
+                if "doi" in base_model_entry:
+                    gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
+                if "uuid" in base_model_entry:
+                    gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
+                if "repo_url" in base_model_entry:
+                    gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
+
+        if self.tags is not None:
+            gguf_writer.add_tags(self.tags)
+        if self.languages is not None:
+            gguf_writer.add_languages(self.languages)
+        if self.datasets is not None:
+            gguf_writer.add_datasets(self.datasets)
diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
new file mode 100644 (file)
index 0000000..ef76831
--- /dev/null
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from typing import Literal
+
+
+def fill_templated_filename(filename: str, output_type: str | None) -> str:
+    # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
+    ftype_lowercase: str = output_type.lower() if output_type is not None else ""
+    ftype_uppercase: str = output_type.upper() if output_type is not None else ""
+    return filename.format(ftype_lowercase,
+                           outtype=ftype_lowercase, ftype=ftype_lowercase,
+                           OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase)
+
+
+def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str:
+    if model_params_count > 1e12 :
+        # Trillions Of Parameters
+        scaled_model_params = model_params_count * 1e-12
+        scale_suffix = "T"
+    elif model_params_count > 1e9 :
+        # Billions Of Parameters
+        scaled_model_params = model_params_count * 1e-9
+        scale_suffix = "B"
+    elif model_params_count > 1e6 :
+        # Millions Of Parameters
+        scaled_model_params = model_params_count * 1e-6
+        scale_suffix = "M"
+    else:
+        # Thousands Of Parameters
+        scaled_model_params = model_params_count * 1e-3
+        scale_suffix = "K"
+
+    fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0)
+
+    return f"{scaled_model_params:.{fix}f}{scale_suffix}"
+
+
+def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str:
+
+    if expert_count > 0:
+        pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2)
+        size_class = f"{expert_count}x{pretty_size}"
+    else:
+        size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2)
+
+    return size_class
+
+
+def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str:
+    # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
+
+    if base_name is not None:
+        name = base_name.strip().title().replace(' ', '-').replace('/', '-')
+    elif model_name is not None:
+        name = model_name.strip().title().replace(' ', '-').replace('/', '-')
+    else:
+        name = "ggml-model"
+
+    parameters = f"-{size_label}" if size_label is not None else ""
+
+    finetune = f"-{finetune_string.strip().title().replace(' ', '-')}" if finetune_string is not None else ""
+
+    version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else ""
+
+    encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else ""
+
+    kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
+
+    return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
index 62129126bdddcffdeb23dfab95b77a70fde6b2ab..19f6761e2f91259fd0af477b8e4614046bc18735 100644 (file)
@@ -22,6 +22,7 @@ classifiers = [
 python = ">=3.8"
 numpy = ">=1.17"
 tqdm = ">=4.27"
+pyyaml = ">=5.1"
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"
diff --git a/gguf-py/tests/__init__.py b/gguf-py/tests/__init__.py
new file mode 100644 (file)
index 0000000..d23ff9c
--- /dev/null
@@ -0,0 +1 @@
+from .test_metadata import *
diff --git a/gguf-py/tests/test_gguf.py b/gguf-py/tests/test_gguf.py
deleted file mode 100644 (file)
index 76b5218..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-import gguf  # noqa: F401  # pyright: ignore[reportUnusedImport]
-
-# TODO: add tests
-
-
-def test_write_gguf() -> None:
-    pass
diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py
new file mode 100755 (executable)
index 0000000..3fac821
--- /dev/null
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+
+import unittest
+from pathlib import Path
+import os
+import sys
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf
+
+
+class TestMetadataMethod(unittest.TestCase):
+
+    def test_id_to_title(self):
+        self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
+        self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
+        self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
+
+    def test_get_model_id_components(self):
+        # This is the basic standard form with organization marker
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
+                         ('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Similar to basic standard form but without organization marker
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
+                         ('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Missing version
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
+                         ('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
+
+        # Missing finetune
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
+                         ('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
+
+        # Base name and size label only
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
+                         ('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
+
+        # Base name and version only
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
+                         ('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
+
+        ## Edge Cases ##
+
+        # This is too ambiguous... best to err on caution and output nothing
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
+                         ('Mixtral', None, None, None, None, None))
+
+        # Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
+                         ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
+
+        # Can't detect all non standard form in a heuristically safe way... best to err in caution and output nothing...
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
+                         ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
+
+        # Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
+                         ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
+
+        # Check that it can handle a real model id with no version code
+        # Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
+                         ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))
+
+        # There is some legitimate models with only thousands of parameters
+        self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
+                         ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
+
+        # None standard and not easy to disambiguate
+        self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
+                         ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
+
+        # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
+        self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
+                         ('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))
+
+        # This is a real model id where the weight size has a decimal point
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
+                         ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
+
+        # Uses an underscore in the size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
+                         ('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
+
+        # Uses Iter3 for the version
+        self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
+                         ('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))
+
+        # Has two potential versions in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
+                         ('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))
+
+        # Potential version in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
+                         ('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))
+
+        # Underscore in the basename, and 1m for the context size
+        self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
+                         ('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))
+
+        # Version before the finetune name
+        self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
+                         ('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))
+
+        # TODO: hf suffix which could be ignored but isn't
+        self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
+                         ('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))
+
+        # Two sizes, don't merge them, the other is the number of tokens on which it was trained
+        self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
+                         ('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))
+
+        # It's a trap, there is no size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
+                         ('relu-100B', 'SparseLLM', 'relu', '100b', None, None))
+
+        # Weird size notation
+        self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
+                         ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
+
+    def test_apply_metadata_heuristic_from_model_card(self):
+        model_card = {
+            'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
+            'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
+            'language': ['en'],
+            'datasets': ['teknium/OpenHermes-2.5'],
+            'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
+            'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
+        }
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        expect = gguf.Metadata()
+        expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': 'v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
+        expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
+        expect.languages=['en']
+        expect.datasets=['teknium/OpenHermes-2.5']
+
+        self.assertEqual(got, expect)
+
+    def test_apply_metadata_heuristic_from_hf_parameters(self):
+        hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
+        self.assertEqual(got, expect)
+
+    def test_apply_metadata_heuristic_from_model_dir(self):
+        model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
+        self.assertEqual(got, expect)
+
+
+if __name__ == "__main__":
+    unittest.main()