from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable, Optional
import numpy as np
from sentencepiece import SentencePieceProcessor
return params
+@dataclass
+class Metadata:
+ name: Optional[str] = None
+ author: Optional[str] = None
+ version: Optional[str] = None
+ url: Optional[str] = None
+ description: Optional[str] = None
+ licence: Optional[str] = None
+ source_url: Optional[str] = None
+ source_hf_repo: Optional[str] = None
+
+ @staticmethod
+ def load(metadata_path: Optional[Path]) -> Metadata:
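+ """Load authorship metadata from a JSON file, if one was given.
+
+ Example file contents (illustrative values; keys mirror the LLM_KV_NAMES mapping in llama.cpp):
+ {"general.name": "MyModel", "general.author": "Jane Doe", "general.version": "v1.0"}
+ """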
+ if metadata_path is None or not metadata_path.exists():
+ return Metadata()
+
+ with open(metadata_path, 'r') as file:
+ data = json.load(file)
+
+ # Create a new Metadata instance
+ metadata = Metadata()
+
+ # Assigning values to Metadata attributes if they exist in the JSON file
+ # This is based on LLM_KV_NAMES mapping in llama.cpp
+ metadata.name = data.get("general.name")
+ metadata.author = data.get("general.author")
+ metadata.version = data.get("general.version")
+ metadata.url = data.get("general.url")
+ metadata.description = data.get("general.description")
+ metadata.licence = data.get("general.license")
+ metadata.source_url = data.get("general.source.url")
+ metadata.source_hf_repo = data.get("general.source.huggingface.repository")
+
+ return metadata
+
+
#
# vocab
#
+
@runtime_checkable
class BaseVocab(Protocol):
tokenizer_model: ClassVar[str]
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
- def add_meta_arch(self, params: Params) -> None:
+ def add_meta_model(self, params: Params, metadata: Optional[Metadata]) -> None:
+ # Metadata About The Model And Its Provenance
name = "LLaMA"
-
- # TODO: better logic to determine model name
- if params.n_ctx == 4096:
- name = "LLaMA v2"
+ if metadata is not None and metadata.name is not None:
+ name = metadata.name
elif params.path_model is not None:
- name = str(params.path_model.parent).split('/')[-1]
-
- self.gguf.add_name (name)
- self.gguf.add_vocab_size (params.n_vocab)
- self.gguf.add_context_length (params.n_ctx)
- self.gguf.add_embedding_length (params.n_embd)
- self.gguf.add_block_count (params.n_layer)
- self.gguf.add_feed_forward_length (params.n_ff)
+ name = str(params.path_model.parent).split("/")[-1]
+ elif params.n_ctx == 4096:
+ # Heuristic detection of LLaMA v2 model
+ name = "LLaMA v2"
+
+ self.gguf.add_name(name)
+
+ if metadata is not None:
+ if metadata.author is not None:
+ self.gguf.add_author(metadata.author)
+ if metadata.version is not None:
+ self.gguf.add_version(metadata.version)
+ if metadata.url is not None:
+ self.gguf.add_url(metadata.url)
+ if metadata.description is not None:
+ self.gguf.add_description(metadata.description)
+ if metadata.licence is not None:
+ self.gguf.add_licence(metadata.licence)
+ if metadata.source_url is not None:
+ self.gguf.add_source_url(metadata.source_url)
+ if metadata.source_hf_repo is not None:
+ self.gguf.add_source_hf_repo(metadata.source_hf_repo)
+
+ def add_meta_arch(self, params: Params) -> None:
+ # Metadata About The Neural Architecture Itself
+ self.gguf.add_vocab_size(params.n_vocab)
+ self.gguf.add_context_length(params.n_ctx)
+ self.gguf.add_embedding_length(params.n_embd)
+ self.gguf.add_block_count(params.n_layer)
+ self.gguf.add_feed_forward_length(params.n_ff)
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
self.gguf.add_head_count (params.n_head)
self.gguf.add_head_count_kv (params.n_head_kv)
@staticmethod
def write_vocab_only(
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
- endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False,
+ endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Optional[Metadata] = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
of = OutputFile(fname_out, endianess=endianess)
# meta data
+ of.add_meta_model(params, metadata)
of.add_meta_arch(params)
of.add_meta_vocab(vocab)
of.add_meta_special_vocab(svocab)
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
pad_vocab: bool = False,
+ metadata: Optional[Metadata] = None,
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
of = OutputFile(fname_out, endianess=endianess)
# meta data
+ of.add_meta_model(params, metadata)
of.add_meta_arch(params)
if isinstance(vocab, Vocab):
of.add_meta_vocab(vocab)
raise ValueError(f"Unexpected combination of types: {name_to_type}")
+def model_parameter_count(model: LazyModel) -> int:
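+ """Count the total number of scalar weights across all tensors in the model."""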
+ total_model_parameters = 0
+ for lazy_tensor in model.values():
+ sum_weights_in_tensor = 1
+ for dim in lazy_tensor.shape:
+ sum_weights_in_tensor *= dim
+ total_model_parameters += sum_weights_in_tensor
+ return total_model_parameters
+
+
+def model_parameter_count_rounded_notation(model_params_count: int) -> str:
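+ """Scale a raw parameter count to a short suffix, e.g. 6_738_415_616 -> "7B" (illustrative)."""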
+ if model_params_count > 1e12:
+ # Trillions Of Parameters
+ scaled_model_params = model_params_count * 1e-12
+ scale_suffix = "T"
+ elif model_params_count > 1e9:
+ # Billions Of Parameters
+ scaled_model_params = model_params_count * 1e-9
+ scale_suffix = "B"
+ elif model_params_count > 1e6:
+ # Millions Of Parameters
+ scaled_model_params = model_params_count * 1e-6
+ scale_suffix = "M"
+ else:
+ # Thousands Of Parameters
+ scaled_model_params = model_params_count * 1e-3
+ scale_suffix = "K"
+
+ return f"{round(scaled_model_params)}{scale_suffix}"
+
+
def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
return {name: tensor.astype(output_type.type_for_tensor(name, tensor))
for (name, tensor) in model.items()}
return vocab, special_vocab
-def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
- namestr = {
- GGMLFileType.AllF32: "f32",
- GGMLFileType.MostlyF16: "f16",
- GGMLFileType.MostlyQ8_0:"q8_0",
+def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Optional[Metadata]) -> str:
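+ """Build the default output file stem, e.g. "Mixtral-v0.1-8x47B-Q8_0" or "ggml-model-7B-F16" (illustrative values)."""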
+ quantization = {
+ GGMLFileType.AllF32: "F32",
+ GGMLFileType.MostlyF16: "F16",
+ GGMLFileType.MostlyQ8_0: "Q8_0",
}[file_type]
- ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
+
+ parameters = model_parameter_count_rounded_notation(model_params_count)
+
+ expert_count = ""
+ if params.n_experts is not None:
+ expert_count = f"{params.n_experts}x"
+
+ version = ""
+ if metadata is not None and metadata.version is not None:
+ version = f"-{metadata.version}"
+
+ name = "ggml-model"
+ if metadata is not None and metadata.name is not None:
+ name = metadata.name
+ elif params.path_model is not None:
+ name = params.path_model.name
+
+ return f"{name}{version}-{expert_count}{parameters}-{quantization}"
+
+
+def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Optional[Metadata]) -> Path:
+ default_filename = default_convention_outfile(file_type, params, model_params_count, metadata)
+ ret = model_paths[0].parent / f"{default_filename}.gguf"
if ret in model_paths:
logger.error(
f"Error: Default output path ({ret}) would overwrite the input. "
parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+ parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")
+ parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name")
args = parser.parse_args(args_in)
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
- elif args.dump_single or args.dump:
+ elif args.dump_single or args.dump or args.get_outfile:
# Avoid printing anything besides the dump output
logging.basicConfig(level=logging.WARNING)
else:
logging.basicConfig(level=logging.INFO)
+ metadata = Metadata.load(args.metadata)
+
+ if args.get_outfile:
+ model_plus = load_some_model(args.model)
+ params = Params.load(model_plus)
+ model = convert_model_names(model_plus.model, params, args.skip_unknown)
+ model_params_count = model_parameter_count(model_plus.model)
+ ftype = pick_output_type(model, args.outtype)
+ print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}") # noqa: NP100
+ return
+
if args.no_vocab and args.vocab_only:
raise ValueError("--vocab-only does not make sense with --no-vocab")
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
+ model_params_count = model_parameter_count(model_plus.model)
+ logger.info(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")
+
if args.dump:
do_dump_model(model_plus)
return
f_norm_eps = 1e-5,
)
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
- endianess=endianess, pad_vocab=args.pad_vocab)
+ endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
logger.info(f"Wrote {outfile}")
return
model = convert_model_names(model, params, args.skip_unknown)
ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype)
- outfile = args.outfile or default_outfile(model_plus.paths, ftype)
+ outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata)
params.ftype = ftype
logger.info(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
- concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
+ concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata)
logger.info(f"Wrote {outfile}")