]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert.py : advanced option (#2753)
authorKerfuffle <redacted>
Sat, 26 Aug 2023 20:13:36 +0000 (14:13 -0600)
committerGitHub <redacted>
Sat, 26 Aug 2023 20:13:36 +0000 (23:13 +0300)
* Allow convert.py to convert to q8_0

Fix issue with bounded_parallel_map and greedy consuming iterator

Display elapsed time during conversion

* Add --concurrency option

Minor improvements to help text

Clean up bounded_parallel_map function a bit

* Massive speed improvement thanks to Cebtenzzre

* Refactor types

convert.py

index d44e5a8c48d9d4540607c1bd83626ad1348b7c5a..a15e6ccd2367e9eed76e291043ea06718ef75998 100755 (executable)
@@ -3,6 +3,7 @@
 import gguf
 import argparse
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 import copy
 import enum
 import faulthandler
@@ -17,13 +18,14 @@ import re
 import signal
 import struct
 import sys
+import time
 import zipfile
 import numpy as np
 
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union)
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
 from sentencepiece import SentencePieceProcessor  # type: ignore
 
 if TYPE_CHECKING:
@@ -37,30 +39,70 @@ NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 ARCH=gguf.MODEL_ARCH.LLAMA
 NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
 
+DEFAULT_CONCURRENCY = 8
 #
 # data types
 #
 
 @dataclass(frozen=True)
-class UnquantizedDataType:
+class DataType:
     name: str
+    dtype: 'np.dtype[Any]'
+    valid_conversions: List[str]
 
-DT_F16  = UnquantizedDataType('F16')
-DT_F32  = UnquantizedDataType('F32')
-DT_I32  = UnquantizedDataType('I32')
-DT_BF16 = UnquantizedDataType('BF16')
+    def elements_to_bytes(self, n_elements: int) -> int:
+        return n_elements * self.dtype.itemsize
 
-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class UnquantizedDataType(DataType):
+    pass
 
-DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
-    DT_BF16: np.dtype(np.uint16),
-    DT_F16:  np.dtype(np.float16),
-    DT_F32:  np.dtype(np.float32),
-    DT_I32:  np.dtype(np.int32),
-}
+DT_F16  = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32  = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32  = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+
+@dataclass(frozen=True)
+class QuantizedDataType(DataType):
+    block_size: int
+    quantized_dtype: 'np.dtype[Any]'
+    ggml_type: gguf.GGMLQuantizationType
 
-NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
-    {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+    def quantize(self, arr: NDArray) -> NDArray:
+        raise NotImplementedError(f'Quantization for {self.name} not implemented')
+
+    def elements_to_bytes(self, n_elements: int) -> int:
+        assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}'
+        return self.quantized_dtype.itemsize * (n_elements // self.block_size)
+
+@dataclass(frozen=True)
+class Q8_0QuantizedDataType(QuantizedDataType):
+    # Mini Q8_0 quantization in Python!
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[Tuple[Any, Any]]:
+            d = abs(blocks).max(axis = 1) / np.float32(127)
+            with np.errstate(divide = 'ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+    dtype = np.dtype(np.float32), valid_conversions = [],
+    ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
+    quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
+
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
 
 SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
     'BF16': DT_BF16,
@@ -73,20 +115,22 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
 # TODO: rename to LLAMAFileType
 # TODO: move to `gguf.py`
 class GGMLFileType(enum.IntEnum):
-    AllF32    = 0
-    MostlyF16 = 1  # except 1d tensors
+    AllF32     = 0
+    MostlyF16  = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors
 
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
-        if len(tensor.shape) == 1:
-            # 1D tensors are always F32.
-            return DT_F32
-        elif self == GGMLFileType.AllF32:
-            return DT_F32
-        elif self == GGMLFileType.MostlyF16:
-            return DT_F16
-        else:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
             raise ValueError(self)
+        # 1D tensors are always F32.
+        return dt if len(tensor.shape) > 1 else DT_F32
 
+GGML_FILE_TYPE_TO_DATA_TYPE: Dict[GGMLFileType, DataType] = {
+    GGMLFileType.AllF32    : DT_F32,
+    GGMLFileType.MostlyF16 : DT_F16,
+    GGMLFileType.MostlyQ8_0: DT_Q8_0,
+}
 
 #
 # hparams loading
@@ -415,7 +459,7 @@ class UnquantizedTensor(Tensor):
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
     def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -454,22 +498,6 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 GGMLCompatibleTensor = Union[UnquantizedTensor]
 
 
-class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
-        self.base = base
-        self.n_head = n_head
-        self.data_type = self.base.data_type
-
-    def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
-
-    def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
-
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
-        raise Exception("shouldn't permute twice")
-
-
 @dataclass
 class LazyTensor:
     _load: Callable[[], Tensor]
@@ -479,7 +507,9 @@ class LazyTensor:
 
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # Should be okay if it maps to the same numpy type?
+        assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \
+                (self.data_type, ret.data_type, self.description)
         return ret
 
     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -490,8 +520,8 @@ class LazyTensor:
         return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
 
     def validate_conversion_to(self, data_type: DataType) -> None:
-        if data_type == self.data_type:
-            return
+        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
+            raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.')
 
 
 LazyModel = Dict[str, LazyTensor]
@@ -617,9 +647,7 @@ class LazyUnpickler(pickle.Unpickler):
         info = self.zip_file.getinfo(filename)
 
         def load(offset: int, elm_count: int) -> NDArray:
-            dtype = DATA_TYPE_TO_NUMPY.get(data_type)
-            if dtype is None:
-                raise Exception("tensor stored in unsupported format")
+            dtype = data_type.dtype
             fp = self.zip_file.open(info)
             fp.seek(offset * dtype.itemsize)
             size = elm_count * dtype.itemsize
@@ -683,7 +711,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
 
     def convert(info: Dict[str, Any]) -> LazyTensor:
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
-        numpy_dtype = DATA_TYPE_TO_NUMPY[data_type]
+        numpy_dtype = data_type.dtype
         shape: List[int] = info['shape']
         begin, end = info['data_offsets']
         assert 0 <= begin <= end <= len(byte_buf)
@@ -723,23 +751,35 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 Out = TypeVar('Out')
 
-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
     '''Parallel map, but with backpressure.  If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory.  Specifically, there is a max of one
     output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    if concurrency < 2:
+        yield from map(func, iterable)
+        # Not reached.
+    iterable = iter(iterable)
+    with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
+        done = False
+        for _ in range(concurrency):
+            try:
+                futures.append(executor.submit(func, next(iterable)))
+            except StopIteration:
+                done = True
+                break
+
         while futures:
             result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
+            while not done and len(futures) < concurrency:
+                try:
+                    futures.append(executor.submit(func, next(iterable)))
+                except StopIteration:
+                    done = True
+                    break
             yield result
 
-
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
     if params.n_vocab != vocab.vocab_size:
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
@@ -804,12 +844,11 @@ class OutputFile:
         self.gguf.add_token_types(toktypes)
 
     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
-        n_elements = 1
-        for dim in tensor.shape:
-            n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        n_elements = int(np.prod(tensor.shape))
+        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
+        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
 
     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
@@ -835,7 +874,20 @@ class OutputFile:
         of.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
+        dt, arr = item
+        if not isinstance(dt, QuantizedDataType):
+            return arr
+        return dt.quantize(arr)
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out)
@@ -851,16 +903,19 @@ class OutputFile:
         of.write_meta()
         of.write_tensor_info()
 
-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
         # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
+
+        start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)
 
         of.close()
@@ -872,6 +927,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
         return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
@@ -918,7 +975,7 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
             print(f"skipping tensor {name_new}")
             continue
         else:
-            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
             out[name_new] = lazy_tensor
 
     return out
@@ -1023,6 +1080,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32:    "f32",
         GGMLFileType.MostlyF16: "f16",
+        GGMLFileType.MostlyQ8_0:"q8_0",
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
@@ -1046,12 +1104,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump",        action="store_true",    help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true",    help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
-    parser.add_argument("--outtype",     choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype",     choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
     parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--vocabtype",   choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx",         type=int,               help="model training context (default: based on input)")
+    parser.add_argument("--concurrency", type=int,               help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
     args = parser.parse_args(args_in)
 
     if args.dump_single:
@@ -1073,6 +1132,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = {
             "f32": GGMLFileType.AllF32,
             "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
         }[args.outtype]
 
     print(f"params = {params}")
@@ -1104,7 +1164,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = ftype
         print(f"Writing {outfile}, format {ftype}")
 
-        OutputFile.write_all(outfile, params, model, vocab)
+        OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
         print(f"Wrote {outfile}")