import gguf
import argparse
import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import copy
import enum
import faulthandler
import signal
import struct
import sys
+import time
import zipfile
import numpy as np
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
from sentencepiece import SentencePieceProcessor # type: ignore
if TYPE_CHECKING:
ARCH=gguf.MODEL_ARCH.LLAMA
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
+DEFAULT_CONCURRENCY = 8
#
# data types
#
@dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
name: str
+ dtype: 'np.dtype[Any]'
+ valid_conversions: List[str]
-DT_F16 = UnquantizedDataType('F16')
-DT_F32 = UnquantizedDataType('F32')
-DT_I32 = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+ def elements_to_bytes(self, n_elements: int) -> int:
+ return n_elements * self.dtype.itemsize
-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class UnquantizedDataType(DataType):
+ pass
-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
- DT_BF16: np.dtype(np.uint16),
- DT_F16: np.dtype(np.float16),
- DT_F32: np.dtype(np.float32),
- DT_I32: np.dtype(np.int32),
-}
+DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int32), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+
+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+ block_size: int
+ quantized_dtype: 'np.dtype[Any]'
+ ggml_type: gguf.GGMLQuantizationType
-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
- {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+ def quantize(self, arr: NDArray) -> NDArray:
+ raise NotImplementedError(f'Quantization for {self.name} not implemented')
+
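+ # On-disk size for a quantized tensor: one structured element (scale plus quants) per block of elements.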
+ def elements_to_bytes(self, n_elements: int) -> int:
+ assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+ return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+ # Mini Q8_0 quantization in Python!
+ def quantize(self, arr: NDArray) -> NDArray:
+ assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+ assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+ n_blocks = arr.size // self.block_size
+ blocks = arr.reshape((n_blocks, self.block_size))
+ # Much faster implementation of block quantization contributed by @Cebtenzzre
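+ # Each block is reduced to a per-block scale d = max(|x|) / 127 and 32 rounded int8 quants.
+ # Blocks that are all zeros get d == 0; their (undefined) quotients are overwritten with zero quants.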
+ def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+ d = abs(blocks).max(axis = 1) / np.float32(127)
+ with np.errstate(divide = 'ignore'):
+ qs = (blocks / d[:, None]).round()
+ qs[d == 0] = 0
+ yield from zip(d, qs)
+ return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
+
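+# The structured dtype below mirrors GGML's block_q8_0 layout: a float16 scale 'd' followed by
+# 32 int8 quants 'qs', so every 32-element block takes 2 + 32 = 34 bytes on disk.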
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+ dtype = np.dtype(np.float32), valid_conversions = [],
+ ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
+ quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
+
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+ if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+ raise ValueError(f'Invalid duplicate data type {dt}')
+ NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
'BF16': DT_BF16,
# TODO: rename to LLAMAFileType
# TODO: move to `gguf.py`
class GGMLFileType(enum.IntEnum):
- AllF32 = 0
- MostlyF16 = 1 # except 1d tensors
+ AllF32 = 0
+ MostlyF16 = 1 # except 1d tensors
+ MostlyQ8_0 = 7 # except 1d tensors
def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
- if len(tensor.shape) == 1:
- # 1D tensors are always F32.
- return DT_F32
- elif self == GGMLFileType.AllF32:
- return DT_F32
- elif self == GGMLFileType.MostlyF16:
- return DT_F16
- else:
+ dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+ if dt is None:
raise ValueError(self)
+ # 1D tensors are always F32.
+ return dt if len(tensor.shape) > 1 else DT_F32
+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+ GGMLFileType.AllF32 : DT_F32,
+ GGMLFileType.MostlyF16 : DT_F16,
+ GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}
#
# hparams loading
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
def astype(self, data_type: DataType) -> Tensor:
- dtype = DATA_TYPE_TO_NUMPY[data_type]
+ dtype = data_type.dtype
if self.data_type == DT_BF16:
self.ndarray = bf16_to_fp32(self.ndarray)
return UnquantizedTensor(self.ndarray.astype(dtype))
GGMLCompatibleTensor = Union[UnquantizedTensor]
-class DeferredPermutedTensor(Tensor):
- def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
- self.base = base
- self.n_head = n_head
- self.data_type = self.base.data_type
-
- def astype(self, data_type: DataType) -> Tensor:
- return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
- def to_ggml(self) -> GGMLCompatibleTensor:
- return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
- def permute(self, n_head: int, n_head_kv: int) -> Tensor:
- raise Exception("shouldn't permute twice")
-
-
@dataclass
class LazyTensor:
_load: Callable[[], Tensor]
def load(self) -> Tensor:
ret = self._load()
- assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+ # Should be okay if it maps to the same numpy type?
+ assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+ (self.data_type, ret.data_type, self.description)
return ret
def astype(self, data_type: DataType) -> 'LazyTensor':
return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
def validate_conversion_to(self, data_type: DataType) -> None:
- if data_type == self.data_type:
- return
+ if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+ raise ValueError(f'Conversion from {self.data_type} to {data_type} is not allowed.')
LazyModel = Dict[str, LazyTensor]
info = self.zip_file.getinfo(filename)
def load(offset: int, elm_count: int) -> NDArray:
- dtype = DATA_TYPE_TO_NUMPY.get(data_type)
- if dtype is None:
- raise Exception("tensor stored in unsupported format")
+ dtype = data_type.dtype
fp = self.zip_file.open(info)
fp.seek(offset * dtype.itemsize)
size = elm_count * dtype.itemsize
def convert(info: Dict[str, Any]) -> LazyTensor:
data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
- numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+ numpy_dtype = data_type.dtype
shape: List[int] = info['shape']
begin, end = info['data_offsets']
assert 0 <= begin <= end <= len(byte_buf)
In = TypeVar('In')
Out = TypeVar('Out')
-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
'''Parallel map, but with backpressure. If the caller doesn't call `next`
fast enough, this will stop calling `func` at some point rather than
letting results pile up in memory. Specifically, there is a max of one
output value buffered per thread.'''
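+ # `factory` selects the executor type; write_all() passes ProcessPoolExecutor when the work is CPU-bound.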
- with concurrent.futures.ThreadPoolExecutor() as executor:
+ if concurrency < 2:
+ yield from map(func, iterable)
+ return
+ iterable = iter(iterable)
+ with factory(max_workers = max_workers) as executor:
futures: List[concurrent.futures.Future[Out]] = []
- items_rev = list(iterable)[::-1]
- for i in range(min(concurrency, len(items_rev))):
- futures.append(executor.submit(func, items_rev.pop()))
+ done = False
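+ # Keep at most `concurrency` futures in flight: prime the executor here, then refill it as results are consumed.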
+ for _ in range(concurrency):
+ try:
+ futures.append(executor.submit(func, next(iterable)))
+ except StopIteration:
+ done = True
+ break
+
while futures:
result = futures.pop(0).result()
- if items_rev:
- futures.append(executor.submit(func, items_rev.pop()))
+ while not done and len(futures) < concurrency:
+ try:
+ futures.append(executor.submit(func, next(iterable)))
+ except StopIteration:
+ done = True
+ break
yield result
-
def check_vocab_size(params: Params, vocab: Vocab) -> None:
if params.n_vocab != vocab.vocab_size:
assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
self.gguf.add_token_types(toktypes)
def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
- n_elements = 1
- for dim in tensor.shape:
- n_elements *= dim
- data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
- data_nbytes = n_elements * data_type.itemsize
- self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+ n_elements = int(np.prod(tensor.shape))
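+ # Quantized data types carry their GGML type; pass it through so gguf records the correct on-disk type.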
+ raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+ data_type = getattr(tensor.data_type, 'quantized_dtype', None) or tensor.data_type.dtype
+ data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+ self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
def write_meta(self) -> None:
self.gguf.write_header_to_file()
of.close()
@staticmethod
- def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+ def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
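+ # Return the data type alongside the raw array so maybe_do_quantize() knows whether quantization is needed.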
+ name, lazy_tensor = item
+ tensor = lazy_tensor.load().to_ggml()
+ return (lazy_tensor.data_type, tensor.ndarray)
+
+ @staticmethod
+ def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
+ dt, arr = item
+ if not isinstance(dt, QuantizedDataType):
+ return arr
+ return dt.quantize(arr)
+
+ @staticmethod
+ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
check_vocab_size(params, vocab)
of = OutputFile(fname_out)
of.write_meta()
of.write_tensor_info()
- def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
- name, lazy_tensor = item
- return lazy_tensor.load().to_ggml().ndarray
-
# tensor data
- ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+ ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
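+ # Q8_0 quantization is CPU-bound, so it runs in a separate process pool to sidestep the GIL;
+ # other output types need no quantization and are passed through inline.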
+ if ftype == GGMLFileType.MostlyQ8_0:
+ ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+ else:
+ ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+ start = time.time()
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+ elapsed = time.time() - start
size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
padi = len(str(len(model)))
- print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+ print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
of.gguf.write_tensor_data(ndarray)
of.close()
return GGMLFileType.AllF32
if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
return GGMLFileType.MostlyF16
+ if output_type_str == "q8_0":
+ return GGMLFileType.MostlyQ8_0
name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
print(f"skipping tensor {name_new}")
continue
else:
- print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+ print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
out[name_new] = lazy_tensor
return out
namestr = {
GGMLFileType.AllF32: "f32",
GGMLFileType.MostlyF16: "f16",
+ GGMLFileType.MostlyQ8_0: "q8_0",
}[file_type]
ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
if ret in model_paths:
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
- parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
+ parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
args = parser.parse_args(args_in)
if args.dump_single:
params.ftype = {
"f32": GGMLFileType.AllF32,
"f16": GGMLFileType.MostlyF16,
+ "q8_0": GGMLFileType.MostlyQ8_0,
}[args.outtype]
print(f"params = {params}")
params.ftype = ftype
print(f"Writing {outfile}, format {ftype}")
- OutputFile.write_all(outfile, params, model, vocab)
+ OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
print(f"Wrote {outfile}")