]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf-py : simplify support for quant types (#8838)
authorcompilade <redacted>
Thu, 8 Aug 2024 17:33:09 +0000 (13:33 -0400)
committerGitHub <redacted>
Thu, 8 Aug 2024 17:33:09 +0000 (13:33 -0400)
* gguf-py : use classes for quants

* convert_hf : simplify internal quantization type selection

* gguf-py : fix flake8 lint

* gguf-py : fix BF16 numpy view type

* gguf-py : remove LlamaFileTypeMap

Too specific to 'llama.cpp', and would be a maintenance burden
to keep up to date.

* gguf-py : add generic quantize and dequantize functions

The quant classes no longer need to be known,
only the target or the source type,
for 'quantize' and 'dequantize', respectively.

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/lazy.py
gguf-py/gguf/quants.py

index 38b92bc8110eac1ce8d3cc5b581aa45bd84fe6d1..7136db440644b4a6e6c81c3e28694039e2a13cf5 100755 (executable)
@@ -251,12 +251,7 @@ class Model:
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del name, new_name, bid, n_dims  # unused
-
-        return False
-
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
         return False
@@ -285,54 +280,46 @@ class Model:
             for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
                 data: np.ndarray  # type hint
                 n_dims = len(data.shape)
-                data_dtype = data.dtype
-                data_qtype: gguf.GGMLQuantizationType | None = None
-
-                # when both are True, f32 should win
-                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
-                extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
+                data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
 
                 # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
-                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
-                extra_f32 = any(cond for cond in (
-                    extra_f32,
-                    n_dims == 1,
-                    new_name.endswith("_norm.weight"),
-                ))
+                if n_dims <= 1 or new_name.endswith("_norm.weight"):
+                    data_qtype = gguf.GGMLQuantizationType.F32
 
+                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
                 # Some tensor types are always in float32
-                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
-                    gguf.MODEL_TENSOR.FFN_GATE_INP,
-                    gguf.MODEL_TENSOR.POS_EMBD,
-                    gguf.MODEL_TENSOR.TOKEN_TYPES,
-                ))
-
-                # if f16 desired, convert any float32 2-dim weight tensors to float16
-                extra_f16 = any(cond for cond in (
-                    extra_f16,
-                    (name.endswith(".weight") and n_dims >= 2),
-                ))
-
-                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-                    if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
-                        data = gguf.quantize_bf16(data)
-                        assert data.dtype == np.uint16
-                        data_qtype = gguf.GGMLQuantizationType.BF16
-
-                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
-                        data = gguf.quantize_q8_0(data)
-                        assert data.dtype == np.uint8
-                        data_qtype = gguf.GGMLQuantizationType.Q8_0
+                if data_qtype is False and (
+                    any(
+                        self.match_model_tensor_name(new_name, key, bid)
+                        for key in (
+                            gguf.MODEL_TENSOR.FFN_GATE_INP,
+                            gguf.MODEL_TENSOR.POS_EMBD,
+                            gguf.MODEL_TENSOR.TOKEN_TYPES,
+                        )
+                    )
+                    or not name.endswith(".weight")
+                ):
+                    data_qtype = gguf.GGMLQuantizationType.F32
 
-                    else:  # default to float16 for quantized tensors
-                        if data_dtype != np.float16:
-                            data = data.astype(np.float16)
+                # No override (data_qtype is False), or wants to be quantized (data_qtype is True)
+                if isinstance(data_qtype, bool):
+                    if self.ftype == gguf.LlamaFileType.ALL_F32:
+                        data_qtype = gguf.GGMLQuantizationType.F32
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                         data_qtype = gguf.GGMLQuantizationType.F16
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                        data_qtype = gguf.GGMLQuantizationType.BF16
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
+                        data_qtype = gguf.GGMLQuantizationType.Q8_0
+                    else:
+                        raise ValueError(f"Unknown file type: {self.ftype.name}")
 
-                if data_qtype is None:  # by default, convert to float32
-                    if data_dtype != np.float32:
-                        data = data.astype(np.float32)
-                    data_qtype = gguf.GGMLQuantizationType.F32
+                try:
+                    data = gguf.quants.quantize(data, data_qtype)
+                except gguf.QuantError as e:
+                    logger.warning("%s, %s", e, "falling back to F16")
+                    data_qtype = gguf.GGMLQuantizationType.F16
+                    data = gguf.quants.quantize(data, data_qtype)
 
                 shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
 
@@ -1765,7 +1752,7 @@ class DbrxModel(Model):
 
         return [(new_name, data_torch)]
 
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid  # unused
 
         return n_dims > 1
@@ -2786,18 +2773,22 @@ class MambaModel(Model):
 
         return [(new_name, data_torch)]
 
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del n_dims  # unused
-
-        return bid is not None and new_name in (
-            self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
+        if bid is not None and new_name in (
+            self.format_tensor_name(
+                n, bid, ".weight" if name.endswith(".weight") else ""
+            )
+            for n in [
                 gguf.MODEL_TENSOR.SSM_CONV1D,
                 gguf.MODEL_TENSOR.SSM_X,
                 gguf.MODEL_TENSOR.SSM_DT,
                 gguf.MODEL_TENSOR.SSM_A,
                 gguf.MODEL_TENSOR.SSM_D,
             ]
-        )
+        ):
+            return gguf.GGMLQuantizationType.F32
+
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
 
 
 @Model.register("CohereForCausalLM")
index 59ffd92ea00cc88b67763c51e8e7062488914c8c..89efe0c800964a22dc9463419d25963fbe66c56c 100644 (file)
@@ -1146,6 +1146,9 @@ class GGMLQuantizationType(IntEnum):
     F64     = 28
     IQ1_M   = 29
     BF16    = 30
+    Q4_0_4_4 = 31
+    Q4_0_4_8 = 32
+    Q4_0_8_8 = 33
 
 
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
@@ -1158,7 +1161,7 @@ class LlamaFileType(IntEnum):
     MOSTLY_F16           = 1   # except 1d tensors
     MOSTLY_Q4_0          = 2   # except 1d tensors
     MOSTLY_Q4_1          = 3   # except 1d tensors
-    MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16
+    MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16
     # MOSTLY_Q4_2        = 5   # support has been removed
     # MOSTLY_Q4_3        = 6   # support has been removed
     MOSTLY_Q8_0          = 7   # except 1d tensors
@@ -1187,6 +1190,9 @@ class LlamaFileType(IntEnum):
     MOSTLY_IQ4_XS        = 30  # except 1d tensors
     MOSTLY_IQ1_M         = 31  # except 1d tensors
     MOSTLY_BF16          = 32  # except 1d tensors
+    MOSTLY_Q4_0_4_4      = 33  # except 1d tensors
+    MOSTLY_Q4_0_4_8      = 34  # except 1d tensors
+    MOSTLY_Q4_0_8_8      = 35  # except 1d tensors
 
     GUESSED              = 1024  # not specified in the model file
 
@@ -1260,6 +1266,9 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.F64:     (1, 8),
     GGMLQuantizationType.IQ1_M:   (256, QK_K // 8 + QK_K // 16  + QK_K // 32),
     GGMLQuantizationType.BF16:    (1, 2),
+    GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
+    GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
+    GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
 }
 
 
index ac98d9a92a3e9e34efca19491c0c032e9e3526d0..8d4fece2dca86983286a3c0de15ca86578ce4dfa 100644 (file)
@@ -191,6 +191,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
 class LazyNumpyTensor(LazyBase):
     _tensor_type = np.ndarray
 
+    shape: tuple[int, ...]  # Makes the type checker happy in quants.py
+
     @classmethod
     def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
         # The initial idea was to use np.nan as the fill value,
index f4361d75170768dc75f5a3cb2d8514e54673b25a..a443dd27e62b34515d14afae71f2f229833145f5 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Callable, Sequence
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Sequence
 
 from numpy.typing import DTypeLike
 
@@ -9,32 +10,22 @@ from .lazy import LazyNumpyTensor
 import numpy as np
 
 
-def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
     block_size, type_size = GGML_QUANT_SIZES[quant_type]
     if shape[-1] % block_size != 0:
         raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
     return (*shape[:-1], shape[-1] // block_size * type_size)
 
 
-def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
     block_size, type_size = GGML_QUANT_SIZES[quant_type]
     if shape[-1] % type_size != 0:
         raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
     return (*shape[:-1], shape[-1] // type_size * block_size)
 
 
-# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
-def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
-    n = n.astype(np.float32, copy=False).view(np.uint32)
-    # force nan to quiet
-    n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
-    # round to nearest even
-    n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
-    return n.astype(np.uint16)
-
-
 # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
-def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
+def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
     rows = arr.reshape((-1, arr.shape[-1]))
     osize = 1
     for dim in oshape:
@@ -46,27 +37,6 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
     return out.reshape(oshape)
 
 
-def __quantize_bf16_array(n: np.ndarray) -> np.ndarray:
-    return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
-
-
-__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.uint16)
-
-
-def quantize_bf16(n: np.ndarray):
-    if type(n) is LazyNumpyTensor:
-        return __quantize_bf16_lazy(n)
-    else:
-        return __quantize_bf16_array(n)
-
-
-__q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
-
-
-def can_quantize_to_q8_0(n: np.ndarray) -> bool:
-    return n.shape[-1] % __q8_block_size == 0
-
-
 # round away from zero
 # ref: https://stackoverflow.com/a/59143326/22827863
 def np_roundf(n: np.ndarray) -> np.ndarray:
@@ -76,46 +46,168 @@ def np_roundf(n: np.ndarray) -> np.ndarray:
     return np.sign(n) * b
 
 
-def __quantize_q8_0_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
-    return (*s[:-1], s[-1] // __q8_block_size * __q8_type_size)
-
-
-# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
-def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray:
-    shape = n.shape
-    assert shape[-1] % __q8_block_size == 0
-
-    n_blocks = n.size // __q8_block_size
-
-    blocks = n.reshape((n_blocks, __q8_block_size)).astype(np.float32, copy=False)
+class QuantError(Exception): ...
 
-    d = abs(blocks).max(axis=1, keepdims=True) / 127
-    with np.errstate(divide="ignore"):
-        id = np.where(d == 0, 0, 1 / d)
-    qs = np_roundf(blocks * id)
 
-    # (n_blocks, 2)
-    d = d.astype(np.float16).view(np.uint8)
-    # (n_blocks, block_size)
-    qs = qs.astype(np.int8).view(np.uint8)
+_type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
 
-    assert d.shape[1] + qs.shape[1] == __q8_type_size
 
-    return np.concatenate([d, qs], axis=1).reshape(__quantize_q8_0_shape_change(shape))
-
-
-def __quantize_q8_0_array(n: np.ndarray) -> np.ndarray:
-    return __apply_over_grouped_rows(__quantize_q8_0_rows, arr=n, otype=np.uint8, oshape=__quantize_q8_0_shape_change(n.shape))
-
-
-__quantize_q8_0_lazy = LazyNumpyTensor._wrap_fn(
-    __quantize_q8_0_array,
-    meta_noop=(np.uint8, __quantize_q8_0_shape_change),
-)
+def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+    if qtype == GGMLQuantizationType.F32:
+        return data.astype(np.float32, copy=False)
+    elif qtype == GGMLQuantizationType.F16:
+        return data.astype(np.float16, copy=False)
+    elif (q := _type_traits.get(qtype)) is not None:
+        return q.quantize(data)
+    else:
+        raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
 
 
-def quantize_q8_0(data: np.ndarray):
-    if type(data) is LazyNumpyTensor:
-        return __quantize_q8_0_lazy(data)
+def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+    if qtype == GGMLQuantizationType.F32 or qtype == GGMLQuantizationType.F16:
+        return data.astype(np.float32, copy=False)
+    elif (q := _type_traits.get(qtype)) is not None:
+        return q.dequantize(data)
     else:
-        return __quantize_q8_0_array(data)
+        raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
+
+
+class __Quant(ABC):
+    qtype: GGMLQuantizationType
+    block_size: int
+    type_size: int
+
+    def __init__(self):
+        return TypeError("Quant conversion classes can't have instances")
+
+    def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
+        cls.qtype = qtype
+        cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
+        cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
+            cls.__quantize_array,
+            meta_noop=(np.uint8, cls.__shape_to_bytes)
+        )
+        cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
+            cls.__dequantize_array,
+            meta_noop=(np.float32, cls.__shape_from_bytes)
+        )
+        assert qtype not in _type_traits
+        _type_traits[qtype] = cls
+
+    @classmethod
+    @abstractmethod
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        raise NotImplementedError
+
+    @classmethod
+    def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
+        rows = rows.astype(np.float32, copy=False)
+        shape = rows.shape
+        n_blocks = rows.size // cls.block_size
+        blocks = rows.reshape((n_blocks, cls.block_size))
+        blocks = cls.quantize_blocks(blocks)
+        assert blocks.dtype == np.uint8
+        assert blocks.shape[-1] == cls.type_size
+        return blocks.reshape(cls.__shape_to_bytes(shape))
+
+    @classmethod
+    def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
+        rows = rows.view(np.uint8)
+        shape = rows.shape
+        n_blocks = rows.size // cls.type_size
+        blocks = rows.reshape((n_blocks, cls.type_size))
+        blocks = cls.dequantize_blocks(blocks)
+        assert blocks.dtype == np.float32
+        assert blocks.shape[-1] == cls.block_size
+        return blocks.reshape(cls.__shape_from_bytes(shape))
+
+    @classmethod
+    def __shape_to_bytes(cls, shape: Sequence[int]):
+        return quant_shape_to_byte_shape(shape, cls.qtype)
+
+    @classmethod
+    def __shape_from_bytes(cls, shape: Sequence[int]):
+        return quant_shape_from_byte_shape(shape, cls.qtype)
+
+    @classmethod
+    def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
+        return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))
+
+    @classmethod
+    def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
+        return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))
+
+    @classmethod
+    def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
+        pass
+
+    @classmethod
+    def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
+        pass
+
+    @classmethod
+    def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
+        return tensor.shape[-1] % cls.block_size == 0
+
+    @classmethod
+    def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
+        if not cls.can_quantize(tensor):
+            raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
+        if isinstance(tensor, LazyNumpyTensor):
+            return cls.__quantize_lazy(tensor)
+        else:
+            return cls.__quantize_array(tensor)
+
+    @classmethod
+    def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
+        if isinstance(tensor, LazyNumpyTensor):
+            return cls.__dequantize_lazy(tensor)
+        else:
+            return cls.__dequantize_array(tensor)
+
+
+class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
+    @classmethod
+    # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n = blocks.view(np.uint32)
+        # force nan to quiet
+        n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
+        # round to nearest even
+        n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
+        return n.astype(np.uint16).view(np.uint8)
+
+    @classmethod
+    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
+
+
+class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
+    @classmethod
+    # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+
+        d = abs(blocks).max(axis=1, keepdims=True) / 127
+        with np.errstate(divide="ignore"):
+            id = np.where(d == 0, 0, 1 / d)
+        qs = np_roundf(blocks * id)
+
+        # (n_blocks, 2)
+        d = d.astype(np.float16).view(np.uint8)
+        # (n_blocks, block_size)
+        qs = qs.astype(np.int8).view(np.uint8)
+
+        return np.concatenate([d, qs], axis=1)
+
+    @classmethod
+    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        d, x = np.split(blocks, [2], axis=1)
+        d = d.view(np.float16).astype(np.float32)
+        x = x.view(np.int8).astype(np.float32)
+
+        return (x * d)