return [(self.map_tensor_name(name), data_torch)]
- def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
- del name, new_name, bid, n_dims # unused
-
- return False
-
- def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
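+        # Return False for no override, True to let the output file type decide, or a
+        # specific GGMLQuantizationType to force that type for this tensor.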
del name, new_name, bid, n_dims # unused
return False
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
data: np.ndarray # type hint
n_dims = len(data.shape)
- data_dtype = data.dtype
- data_qtype: gguf.GGMLQuantizationType | None = None
-
- # when both are True, f32 should win
- extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
- extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
+ data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
- # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
- extra_f32 = any(cond for cond in (
- extra_f32,
- n_dims == 1,
- new_name.endswith("_norm.weight"),
- ))
+ if n_dims <= 1 or new_name.endswith("_norm.weight"):
+ data_qtype = gguf.GGMLQuantizationType.F32
+ # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
# Some tensor types are always in float32
- extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
- gguf.MODEL_TENSOR.FFN_GATE_INP,
- gguf.MODEL_TENSOR.POS_EMBD,
- gguf.MODEL_TENSOR.TOKEN_TYPES,
- ))
-
- # if f16 desired, convert any float32 2-dim weight tensors to float16
- extra_f16 = any(cond for cond in (
- extra_f16,
- (name.endswith(".weight") and n_dims >= 2),
- ))
-
- if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
- if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
- data = gguf.quantize_bf16(data)
- assert data.dtype == np.uint16
- data_qtype = gguf.GGMLQuantizationType.BF16
-
- elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
- data = gguf.quantize_q8_0(data)
- assert data.dtype == np.uint8
- data_qtype = gguf.GGMLQuantizationType.Q8_0
+ if data_qtype is False and (
+ any(
+ self.match_model_tensor_name(new_name, key, bid)
+ for key in (
+ gguf.MODEL_TENSOR.FFN_GATE_INP,
+ gguf.MODEL_TENSOR.POS_EMBD,
+ gguf.MODEL_TENSOR.TOKEN_TYPES,
+ )
+ )
+ or not name.endswith(".weight")
+ ):
+ data_qtype = gguf.GGMLQuantizationType.F32
- else: # default to float16 for quantized tensors
- if data_dtype != np.float16:
- data = data.astype(np.float16)
+ # No override (data_qtype is False), or wants to be quantized (data_qtype is True)
+ if isinstance(data_qtype, bool):
+ if self.ftype == gguf.LlamaFileType.ALL_F32:
+ data_qtype = gguf.GGMLQuantizationType.F32
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_qtype = gguf.GGMLQuantizationType.F16
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+ data_qtype = gguf.GGMLQuantizationType.BF16
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
+ data_qtype = gguf.GGMLQuantizationType.Q8_0
+ else:
+ raise ValueError(f"Unknown file type: {self.ftype.name}")
- if data_qtype is None: # by default, convert to float32
- if data_dtype != np.float32:
- data = data.astype(np.float32)
- data_qtype = gguf.GGMLQuantizationType.F32
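+                    # quantization can raise QuantError (e.g. when the row size is not a
+                    # multiple of the block size); fall back to F16 in that case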
+ try:
+ data = gguf.quants.quantize(data, data_qtype)
+ except gguf.QuantError as e:
+                        logger.warning("%s, falling back to F16", e)
+ data_qtype = gguf.GGMLQuantizationType.F16
+ data = gguf.quants.quantize(data, data_qtype)
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
return [(new_name, data_torch)]
- def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid # unused
return n_dims > 1
return [(new_name, data_torch)]
- def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
- del n_dims # unused
-
- return bid is not None and new_name in (
- self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
+ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
+ if bid is not None and new_name in (
+ self.format_tensor_name(
+ n, bid, ".weight" if name.endswith(".weight") else ""
+ )
+ for n in [
gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]
- )
+ ):
+ return gguf.GGMLQuantizationType.F32
+
+ return super().tensor_force_quant(name, new_name, bid, n_dims)
@Model.register("CohereForCausalLM")
from __future__ import annotations
-from typing import Callable, Sequence
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Sequence
from numpy.typing import DTypeLike
import numpy as np
-def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
block_size, type_size = GGML_QUANT_SIZES[quant_type]
if shape[-1] % block_size != 0:
raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
return (*shape[:-1], shape[-1] // block_size * type_size)
-def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
+def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
block_size, type_size = GGML_QUANT_SIZES[quant_type]
if shape[-1] % type_size != 0:
raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
return (*shape[:-1], shape[-1] // type_size * block_size)
-# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
-def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
- n = n.astype(np.float32, copy=False).view(np.uint32)
- # force nan to quiet
- n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
- # round to nearest even
- n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
- return n.astype(np.uint16)
-
-
# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
-def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
+def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
rows = arr.reshape((-1, arr.shape[-1]))
osize = 1
for dim in oshape:
return out.reshape(oshape)
-def __quantize_bf16_array(n: np.ndarray) -> np.ndarray:
- return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
-
-
-__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.uint16)
-
-
-def quantize_bf16(n: np.ndarray):
- if type(n) is LazyNumpyTensor:
- return __quantize_bf16_lazy(n)
- else:
- return __quantize_bf16_array(n)
-
-
-__q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
-
-
-def can_quantize_to_q8_0(n: np.ndarray) -> bool:
- return n.shape[-1] % __q8_block_size == 0
-
-
# round away from zero
# ref: https://stackoverflow.com/a/59143326/22827863
def np_roundf(n: np.ndarray) -> np.ndarray:
return np.sign(n) * b
-def __quantize_q8_0_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
- return (*s[:-1], s[-1] // __q8_block_size * __q8_type_size)
-
-
-# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
-def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray:
- shape = n.shape
- assert shape[-1] % __q8_block_size == 0
-
- n_blocks = n.size // __q8_block_size
-
- blocks = n.reshape((n_blocks, __q8_block_size)).astype(np.float32, copy=False)
+class QuantError(Exception): ...
- d = abs(blocks).max(axis=1, keepdims=True) / 127
- with np.errstate(divide="ignore"):
- id = np.where(d == 0, 0, 1 / d)
- qs = np_roundf(blocks * id)
- # (n_blocks, 2)
- d = d.astype(np.float16).view(np.uint8)
- # (n_blocks, block_size)
- qs = qs.astype(np.int8).view(np.uint8)
+_type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
- assert d.shape[1] + qs.shape[1] == __q8_type_size
- return np.concatenate([d, qs], axis=1).reshape(__quantize_q8_0_shape_change(shape))
-
-
-def __quantize_q8_0_array(n: np.ndarray) -> np.ndarray:
- return __apply_over_grouped_rows(__quantize_q8_0_rows, arr=n, otype=np.uint8, oshape=__quantize_q8_0_shape_change(n.shape))
-
-
-__quantize_q8_0_lazy = LazyNumpyTensor._wrap_fn(
- __quantize_q8_0_array,
- meta_noop=(np.uint8, __quantize_q8_0_shape_change),
-)
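+# Quantization dispatch: F32/F16 are plain numpy casts; other types go through the
+# __Quant subclass registered for them in _type_traits below.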
+def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+ if qtype == GGMLQuantizationType.F32:
+ return data.astype(np.float32, copy=False)
+ elif qtype == GGMLQuantizationType.F16:
+ return data.astype(np.float16, copy=False)
+ elif (q := _type_traits.get(qtype)) is not None:
+ return q.quantize(data)
+ else:
+ raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
-def quantize_q8_0(data: np.ndarray):
- if type(data) is LazyNumpyTensor:
- return __quantize_q8_0_lazy(data)
+def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+ if qtype == GGMLQuantizationType.F32 or qtype == GGMLQuantizationType.F16:
+ return data.astype(np.float32, copy=False)
+ elif (q := _type_traits.get(qtype)) is not None:
+ return q.dequantize(data)
else:
- return __quantize_q8_0_array(data)
+ raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
+
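+# e.g. (illustrative) round-trip through Q8_0; the last dimension must be a multiple
+# of the Q8_0 block size:
+#   q = quantize(np.zeros((4, 64), dtype=np.float32), GGMLQuantizationType.Q8_0)
+#   f = dequantize(q, GGMLQuantizationType.Q8_0)
+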
+
+class __Quant(ABC):
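+    # Base class for a single GGML quantization type: subclasses implement
+    # quantize_blocks()/dequantize_blocks() and register themselves via __init_subclass__.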
+ qtype: GGMLQuantizationType
+ block_size: int
+ type_size: int
+
+ def __init__(self):
+        raise TypeError("Quant conversion classes can't have instances")
+
+ def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
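+        # record the block/type sizes, wrap the array (de)quantizers for lazy tensors,
+        # and register this class in _type_traits for quantize()/dequantize() dispatch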
+ cls.qtype = qtype
+ cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
+ cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
+ cls.__quantize_array,
+ meta_noop=(np.uint8, cls.__shape_to_bytes)
+ )
+ cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
+ cls.__dequantize_array,
+ meta_noop=(np.float32, cls.__shape_from_bytes)
+ )
+ assert qtype not in _type_traits
+ _type_traits[qtype] = cls
+
+ @classmethod
+ @abstractmethod
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ raise NotImplementedError
+
+ @classmethod
+ def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
+ rows = rows.astype(np.float32, copy=False)
+ shape = rows.shape
+ n_blocks = rows.size // cls.block_size
+ blocks = rows.reshape((n_blocks, cls.block_size))
+ blocks = cls.quantize_blocks(blocks)
+ assert blocks.dtype == np.uint8
+ assert blocks.shape[-1] == cls.type_size
+ return blocks.reshape(cls.__shape_to_bytes(shape))
+
+ @classmethod
+ def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
+ rows = rows.view(np.uint8)
+ shape = rows.shape
+ n_blocks = rows.size // cls.type_size
+ blocks = rows.reshape((n_blocks, cls.type_size))
+ blocks = cls.dequantize_blocks(blocks)
+ assert blocks.dtype == np.float32
+ assert blocks.shape[-1] == cls.block_size
+ return blocks.reshape(cls.__shape_from_bytes(shape))
+
+ @classmethod
+ def __shape_to_bytes(cls, shape: Sequence[int]):
+ return quant_shape_to_byte_shape(shape, cls.qtype)
+
+ @classmethod
+ def __shape_from_bytes(cls, shape: Sequence[int]):
+ return quant_shape_from_byte_shape(shape, cls.qtype)
+
+ @classmethod
+ def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
+ return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))
+
+ @classmethod
+ def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
+ return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))
+
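+    # Stub signatures: __init_subclass__ rebinds these (per subclass) to
+    # LazyNumpyTensor._wrap_fn wrappers around __quantize_array/__dequantize_array.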
+ @classmethod
+ def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
+ pass
+
+ @classmethod
+ def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
+ pass
+
+ @classmethod
+ def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
+ return tensor.shape[-1] % cls.block_size == 0
+
+ @classmethod
+ def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
+ if not cls.can_quantize(tensor):
+ raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
+ if isinstance(tensor, LazyNumpyTensor):
+ return cls.__quantize_lazy(tensor)
+ else:
+ return cls.__quantize_array(tensor)
+
+ @classmethod
+ def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
+ if isinstance(tensor, LazyNumpyTensor):
+ return cls.__dequantize_lazy(tensor)
+ else:
+ return cls.__dequantize_array(tensor)
+
+
+class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
+ @classmethod
+ # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+ n = blocks.view(np.uint32)
+ # force nan to quiet
+ n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
+ # round to nearest even
+ n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
+ return n.astype(np.uint16).view(np.uint8)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
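+        # BF16 keeps only the upper 16 bits of a float32; shift them back into the high half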
+ return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
+
+
+class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
+ @classmethod
+ # Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
+ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+
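+        # per-block scale: the absolute maximum of the block mapped onto the int8 range [-127, 127]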
+ d = abs(blocks).max(axis=1, keepdims=True) / 127
+ with np.errstate(divide="ignore"):
+ id = np.where(d == 0, 0, 1 / d)
+ qs = np_roundf(blocks * id)
+
+ # (n_blocks, 2)
+ d = d.astype(np.float16).view(np.uint8)
+ # (n_blocks, block_size)
+ qs = qs.astype(np.int8).view(np.uint8)
+
+ return np.concatenate([d, qs], axis=1)
+
+ @classmethod
+ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
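+        # each block is a float16 scale d (2 bytes) followed by block_size int8 quants; value = d * q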
+ d, x = np.split(blocks, [2], axis=1)
+ d = d.view(np.float16).astype(np.float32)
+ x = x.view(np.int8).astype(np.float32)
+
+ return (x * d)