]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert-hf : support bfloat16 conversion (#7158)
authorcompilade <redacted>
Sat, 11 May 2024 15:06:26 +0000 (11:06 -0400)
committerGitHub <redacted>
Sat, 11 May 2024 15:06:26 +0000 (11:06 -0400)
* convert-hf : support bfloat16 conversion

* gguf-py : flake8 fixes

* convert-hf : add missing space after comma

* convert-hf : get bit-exact same output as ./quantize

The quantization version was missing.

* convert-hf : don't round bf16 NANs

* convert-hf : save some memory with np.int16 intermediate bf16 weights

* convert-hf : more closely match llama.cpp with which weights to keep in f32

* convert-hf : add --outtype auto-f16

A reason for this to exist is for model quantizers who want an initial
GGUF with the most fidelity to the original model while still using
a 16-bit float type instead of 32-bit floats.

* convert-hf : remove a semicolon because flake8 doesn't like it

It's a reflex from when programming in C/C++, I guess.

* convert-hf : support outtype templating in outfile name

* convert-hf : rename --outtype auto-f16 to --outtype auto

convert-hf-to-gguf.py
gguf-py/gguf/__init__.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/lazy.py [new file with mode: 0644]

index fbaed64da1cac4d12569f35701357df927e2078d..ec7f4dd758c72c67fd8b67dce9cdf1e95ccec05d 100755 (executable)
@@ -12,7 +12,7 @@ import sys
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
 
 import numpy as np
 import torch
@@ -48,7 +48,6 @@ class Model:
 
     dir_model: Path
     ftype: int
-    fname_out: Path
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
@@ -56,20 +55,20 @@ class Model:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    gguf_writer: gguf.GGUFWriter
     block_count: int
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
+    fname_out: Path
+    gguf_writer: gguf.GGUFWriter
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
-        if self.__class__ == Model:
-            raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
+        if type(self) is Model:
+            raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
         self.ftype = ftype
-        self.fname_out = fname_out
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
@@ -79,10 +78,23 @@ class Model:
         if not self.is_safetensors:
             self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
         self.hparams = Model.load_hparams(self.dir_model)
-        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
+        if self.ftype == gguf.LlamaFileType.GUESSED:
+            # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
+            _, first_tensor = next(self.get_tensors())
+            if first_tensor.dtype == torch.float16:
+                logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
+                self.ftype = gguf.LlamaFileType.MOSTLY_F16
+            else:
+                logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
+                self.ftype = gguf.LlamaFileType.MOSTLY_BF16
+        ftype_up: str = self.ftype.name.partition("_")[2].upper()
+        ftype_lw: str = ftype_up.lower()
+        # allow templating the file name with the output ftype, useful with the "auto" ftype
+        self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
+        self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
 
     @classmethod
     def __init_subclass__(cls):
@@ -142,14 +154,27 @@ class Model:
             raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
 
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
-        name: str = gguf.TENSOR_NAMES[key]
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
             raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
+        name: str = gguf.TENSOR_NAMES[key]
         if "{bid}" in name:
             assert bid is not None
             name = name.format(bid=bid)
         return name + suffix
 
+    def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
+        if key not in gguf.MODEL_TENSORS[self.model_arch]:
+            return False
+        key_name: str = gguf.TENSOR_NAMES[key]
+        if "{bid}" in key_name:
+            if bid is None:
+                return False
+            key_name = key_name.format(bid=bid)
+        else:
+            if bid is not None:
+                return False
+        return name == (key_name + suffix)
+
     def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
         new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
         if new_name is None:
@@ -215,6 +240,23 @@ class Model:
         return False
 
     def write_tensors(self):
+        # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+        def np_fp32_to_bf16(n: np.ndarray):
+            # force nan to quiet
+            n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
+            # flush subnormals to zero
+            n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
+            # round to nearest even
+            n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
+            return n.astype(np.int16)
+
+        # Doing this row-wise is much, much faster than element-wise, hence the signature
+        v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
+        if self.lazy:
+            # TODO: find a way to implicitly wrap np.vectorize functions
+            # NOTE: the type is changed to reflect otypes passed to np.vectorize above
+            v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
+
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
         for name, data_torch in self.get_tensors():
@@ -239,35 +281,60 @@ class Model:
                 data: np.ndarray = data  # type hint
                 n_dims = len(data.shape)
                 data_dtype = data.dtype
-
-                # if f32 desired, convert any float16 to float32
-                if self.ftype == 0 and data_dtype == np.float16:
-                    data = data.astype(np.float32)
+                data_qtype: gguf.GGMLQuantizationType | None = None
 
                 # when both are True, f32 should win
                 extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
                 extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
 
                 # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
-                extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight")
+                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
+                extra_f32 = any(cond for cond in (
+                    extra_f32,
+                    n_dims == 1,
+                    new_name.endswith("_norm.weight"),
+                ))
+
+                # Some tensor types are always in float32
+                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
+                    gguf.MODEL_TENSOR.FFN_GATE_INP,
+                    gguf.MODEL_TENSOR.POS_EMBD,
+                    gguf.MODEL_TENSOR.TOKEN_TYPES,
+                ))
 
                 # if f16 desired, convert any float32 2-dim weight tensors to float16
-                extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
-
-                # when both extra_f32 and extra_f16 are False, convert to float32 by default
-                if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
-                    data = data.astype(np.float32)
-
-                if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
-                    data = data.astype(np.float16)
+                extra_f16 = any(cond for cond in (
+                    extra_f16,
+                    (name.endswith(".weight") and n_dims >= 2),
+                ))
+
+                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                        if data_dtype != np.float16:
+                            data = data.astype(np.float16)
+                        data_qtype = gguf.GGMLQuantizationType.F16
+
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                        if data_dtype != np.float32:
+                            data = data.astype(np.float32)
+                        data = v_fp32_to_bf16(data.view(np.int32))
+                        assert data.dtype == np.int16
+                        data_qtype = gguf.GGMLQuantizationType.BF16
+
+                else:  # by default, convert to float32
+                    if data_dtype != np.float32:
+                        data = data.astype(np.float32)
+                    data_qtype = gguf.GGMLQuantizationType.F32
+
+                assert data_qtype is not None
 
                 # reverse shape to make it similar to the internal ggml dimension order
                 shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
 
                 # n_dims is implicit in the shape
-                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
+                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
 
-                self.gguf_writer.add_tensor(new_name, data)
+                self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
 
     def write(self):
         self.write_tensors()
@@ -2044,12 +2111,6 @@ class BertModel(Model):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del new_name, bid, n_dims  # unused
-
-        # not used with get_rows, must be F32
-        return name == "embeddings.token_type_embeddings.weight"
-
 
 @Model.register("NomicBertModel")
 class NomicBertModel(BertModel):
@@ -2339,92 +2400,40 @@ class JinaBertV2Model(BertModel):
 
 
 # tree of lazy tensors
-class LazyTorchTensor:
-    _meta: Tensor
-    _data: Tensor | None
-    _args: tuple
-    _func: Callable[[tuple], Tensor] | None
-
-    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
-        self._meta = meta
-        self._data = data
-        self._args = args
-        self._func = func
-
-    @staticmethod
-    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
-        # TODO: dict and set
-        if isinstance(o, (list, tuple)):
-            L = []
-            for item in o:
-                L.append(LazyTorchTensor._recurse_apply(item, fn))
-            if isinstance(o, tuple):
-                L = tuple(L)
-            return L
-        elif isinstance(o, LazyTorchTensor):
-            return fn(o)
-        else:
-            return o
-
-    def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
-        def wrapped_fn(*args, **kwargs):
-            if kwargs is None:
-                kwargs = {}
-            args = ((self,) if use_self else ()) + args
-
-            meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
-
-            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
-        return wrapped_fn
-
-    def __getattr__(self, __name: str) -> Any:
-        meta_attr = getattr(self._meta, __name)
-        if callable(meta_attr):
-            return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
-        elif isinstance(meta_attr, torch.Tensor):
-            # for things like self.T
-            return self._wrap_fn(lambda s: getattr(s, __name))(self)
-        else:
-            return meta_attr
+class LazyTorchTensor(gguf.LazyBase):
+    _tensor_type = torch.Tensor
+    # to keep the type-checker happy
+    dtype: torch.dtype
+    shape: torch.Size
 
+    # only used when converting a torch.Tensor to a np.ndarray
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
         torch.float32: np.float32,
     }
 
-    def numpy(self) -> gguf.LazyTensor:
+    def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
-        return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
-
-    @overload
-    @staticmethod
-    def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
-
-    @overload
-    @staticmethod
-    def to_eager(t: tuple) -> tuple: ...
-
-    @staticmethod
-    def to_eager(t: Any) -> Any:
-        def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
-            # wake up the lazy tensor
-            if _t._data is None and _t._func is not None:
-                # recurse into its arguments
-                _t._args = LazyTorchTensor.to_eager(_t._args)
-                _t._data = _t._func(_t._args)
-            if _t._data is not None:
-                return _t._data
-            else:
-                raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
-
-        # recurse into lists and/or tuples, keeping their structure
-        return LazyTorchTensor._recurse_apply(t, simple_to_eager)
+        return gguf.LazyNumpyTensor(
+            meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
+            lazy=self._lazy,
+            args=(self,),
+            func=(lambda s: s[0].numpy())
+        )
 
-    @staticmethod
-    def from_eager(t: Tensor) -> Tensor:
-        if (t.__class__ == LazyTorchTensor):
+    @classmethod
+    def eager_to_meta(cls, t: Tensor) -> Tensor:
+        if t.is_meta:
             return t
-        return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore
+        return t.detach().to("meta")
+
+    @classmethod
+    def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
+        m = m.detach()
+        if not m.is_meta:
+            m = m.to("meta")
+        m.dtype = dtype
+        return m
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -2435,28 +2444,8 @@ class LazyTorchTensor:
 
         if func is torch.Tensor.numpy:
             return args[0].numpy()
-        if func is torch.equal:
-            eager_args = LazyTorchTensor.to_eager(args)
-            return func(*eager_args, **kwargs)
 
-        return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
-
-    # special methods bypass __getattr__, so they need to be added manually
-    # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
-    # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
-    #       as self._meta is currently used), because then the following
-    #       operations would by default not be wrapped, and so not propagated
-    #       when the tensor is made eager.
-    #       It's better to get non-silent errors for not-yet-supported operators.
-    # TODO: add more when needed to avoid clutter, or find a more concise way
-    def __neg__(self, *args):  # mamba
-        return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
-
-    def __add__(self, *args):  # gemma
-        return self._wrap_fn(torch.Tensor.__add__)(self, *args)
-
-    def __getitem__(self, *args):  # bloom falcon refact internlm2
-        return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
+        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
 
 
 def parse_args() -> argparse.Namespace:
@@ -2472,11 +2461,11 @@ def parse_args() -> argparse.Namespace:
     )
     parser.add_argument(
         "--outfile", type=Path,
-        help="path to write to; default: based on input",
+        help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16"], default="f16",
-        help="output format - use f32 for float32, f16 for float16",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
@@ -2530,16 +2519,18 @@ def main() -> None:
         logger.error(f'Error: {args.model} is not a directory')
         sys.exit(1)
 
-    ftype_map = {
-        "f32": gguf.GGMLQuantizationType.F32,
-        "f16": gguf.GGMLQuantizationType.F16,
+    ftype_map: dict[str, gguf.LlamaFileType] = {
+        "f32": gguf.LlamaFileType.ALL_F32,
+        "f16": gguf.LlamaFileType.MOSTLY_F16,
+        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+        "auto": gguf.LlamaFileType.GUESSED,
     }
 
     if args.outfile is not None:
         fname_out = args.outfile
     else:
         # output in the same directory as the model by default
-        fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
+        fname_out = dir_model / 'ggml-model-{ftype}.gguf'
 
     logger.info(f"Loading model: {dir_model.name}")
 
@@ -2555,14 +2546,16 @@ def main() -> None:
         logger.info("Set model tokenizer")
         model_instance.set_vocab()
 
+        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+
         if args.vocab_only:
-            logger.info(f"Exporting model vocab to '{fname_out}'")
+            logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
             model_instance.write_vocab()
         else:
-            logger.info(f"Exporting model to '{fname_out}'")
+            logger.info(f"Exporting model to '{model_instance.fname_out}'")
             model_instance.write()
 
-        logger.info(f"Model successfully exported to '{fname_out}'")
+        logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
 
 
 if __name__ == '__main__':
index 110ab342ccd719fe9617b73c51c0f38ca6c7872f..e5d5806c81e5e370730b8e182ab1becc513a4438 100644 (file)
@@ -1,4 +1,5 @@
 from .constants import *
+from .lazy import *
 from .gguf_reader import *
 from .gguf_writer import *
 from .tensor_mapping import *
index a4fbfc5e09d0632840c42fad1af76609b300069a..978fcada3b42c58ff80a6b99ba3131e7cc4a4d9b 100644 (file)
@@ -10,6 +10,7 @@ from typing import Any
 GGUF_MAGIC             = 0x46554747  # "GGUF"
 GGUF_VERSION           = 3
 GGUF_DEFAULT_ALIGNMENT = 32
+GGML_QUANT_VERSION     = 2  # GGML_QNT_VERSION from ggml.h
 
 #
 # metadata keys
@@ -838,6 +839,49 @@ class GGMLQuantizationType(IntEnum):
     BF16    = 30
 
 
+# TODO: add GGMLFileType from ggml_ftype in ggml.h
+
+
+# from llama_ftype in llama.h
+# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
+class LlamaFileType(IntEnum):
+    ALL_F32              = 0
+    MOSTLY_F16           = 1   # except 1d tensors
+    MOSTLY_Q4_0          = 2   # except 1d tensors
+    MOSTLY_Q4_1          = 3   # except 1d tensors
+    MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16
+    # MOSTLY_Q4_2        = 5   # support has been removed
+    # MOSTLY_Q4_3        = 6   # support has been removed
+    MOSTLY_Q8_0          = 7   # except 1d tensors
+    MOSTLY_Q5_0          = 8   # except 1d tensors
+    MOSTLY_Q5_1          = 9   # except 1d tensors
+    MOSTLY_Q2_K          = 10  # except 1d tensors
+    MOSTLY_Q3_K_S        = 11  # except 1d tensors
+    MOSTLY_Q3_K_M        = 12  # except 1d tensors
+    MOSTLY_Q3_K_L        = 13  # except 1d tensors
+    MOSTLY_Q4_K_S        = 14  # except 1d tensors
+    MOSTLY_Q4_K_M        = 15  # except 1d tensors
+    MOSTLY_Q5_K_S        = 16  # except 1d tensors
+    MOSTLY_Q5_K_M        = 17  # except 1d tensors
+    MOSTLY_Q6_K          = 18  # except 1d tensors
+    MOSTLY_IQ2_XXS       = 19  # except 1d tensors
+    MOSTLY_IQ2_XS        = 20  # except 1d tensors
+    MOSTLY_Q2_K_S        = 21  # except 1d tensors
+    MOSTLY_IQ3_XS        = 22  # except 1d tensors
+    MOSTLY_IQ3_XXS       = 23  # except 1d tensors
+    MOSTLY_IQ1_S         = 24  # except 1d tensors
+    MOSTLY_IQ4_NL        = 25  # except 1d tensors
+    MOSTLY_IQ3_S         = 26  # except 1d tensors
+    MOSTLY_IQ3_M         = 27  # except 1d tensors
+    MOSTLY_IQ2_S         = 28  # except 1d tensors
+    MOSTLY_IQ2_M         = 29  # except 1d tensors
+    MOSTLY_IQ4_XS        = 30  # except 1d tensors
+    MOSTLY_IQ1_M         = 31  # except 1d tensors
+    MOSTLY_BF16          = 32  # except 1d tensors
+
+    GUESSED              = 1024  # not specified in the model file
+
+
 class GGUFEndian(IntEnum):
     LITTLE = 0
     BIG = 1
index 8dcf9330b076fb35b875fb9600a268364bed73a1..96574358d66bb1575bfaa6876802da020d68d4b7 100644 (file)
@@ -7,7 +7,7 @@ import struct
 import tempfile
 from enum import Enum, auto
 from io import BufferedWriter
-from typing import IO, Any, Callable, Sequence, Mapping
+from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
 
 import numpy as np
@@ -28,47 +28,6 @@ from .constants import (
 logger = logging.getLogger(__name__)
 
 
-class LazyTensor:
-    data: Callable[[], np.ndarray[Any, Any]]
-    # to avoid too deep recursion
-    functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
-    dtype: np.dtype[Any]
-    shape: tuple[int, ...]
-
-    def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
-        self.data = data
-        self.functions = []
-        self.dtype = np.dtype(dtype)
-        self.shape = shape
-
-    def astype(self, dtype: type, **kwargs) -> LazyTensor:
-        self.functions.append(lambda n: n.astype(dtype, **kwargs))
-        self.dtype = np.dtype(dtype)
-        return self
-
-    @property
-    def nbytes(self) -> int:
-        size = 1
-        for n in self.shape:
-            size *= n
-        return size * self.dtype.itemsize
-
-    def tofile(self, *args, **kwargs) -> None:
-        data = self.data()
-        for f in self.functions:
-            data = f(data)
-        assert data.shape == self.shape
-        assert data.dtype == self.dtype
-        assert data.nbytes == self.nbytes
-        self.functions = []
-        self.data = lambda: data
-        data.tofile(*args, **kwargs)
-
-    def byteswap(self, *args, **kwargs) -> LazyTensor:
-        self.functions.append(lambda n: n.byteswap(*args, **kwargs))
-        return self
-
-
 class WriterState(Enum):
     EMPTY   = auto()
     HEADER  = auto()
@@ -79,7 +38,7 @@ class WriterState(Enum):
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any] | LazyTensor]
+    tensors: list[np.ndarray[Any, Any]]
     _simple_value_packing = {
         GGUFValueType.UINT8:   "B",
         GGUFValueType.INT8:    "b",
@@ -278,7 +237,7 @@ class GGUFWriter:
         self.ti_data_count += 1
 
     def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
+        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.endianess == GGUFEndian.BIG:
@@ -303,7 +262,7 @@ class GGUFWriter:
         if pad != 0:
             fp.write(bytes([0] * pad))
 
-    def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
         if self.state is not WriterState.TI_DATA:
             raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
 
@@ -391,7 +350,7 @@ class GGUFWriter:
     def add_name(self, name: str) -> None:
         self.add_string(Keys.General.NAME, name)
 
-    def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
+    def add_quantization_version(self, quantization_version: int) -> None:
         self.add_uint32(
             Keys.General.QUANTIZATION_VERSION, quantization_version)
 
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py
new file mode 100644 (file)
index 0000000..650bea1
--- /dev/null
@@ -0,0 +1,225 @@
+from __future__ import annotations
+from abc import ABC, ABCMeta, abstractmethod
+
+import logging
+from typing import Any, Callable
+from collections import deque
+
+import numpy as np
+from numpy.typing import DTypeLike
+
+
+logger = logging.getLogger(__name__)
+
+
+class LazyMeta(ABCMeta):
+
+    def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
+        def __getattr__(self, __name: str) -> Any:
+            meta_attr = getattr(self._meta, __name)
+            if callable(meta_attr):
+                return type(self)._wrap_fn(
+                    (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
+                    use_self=self,
+                )
+            elif isinstance(meta_attr, self._tensor_type):
+                # e.g. self.T with torch.Tensor should still be wrapped
+                return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
+            else:
+                # no need to wrap non-tensor properties,
+                # and they likely don't depend on the actual contents of the tensor
+                return meta_attr
+
+        namespace["__getattr__"] = __getattr__
+
+        # need to make a builder for the wrapped wrapper to copy the name,
+        # or else it fails with very cryptic error messages,
+        # because somehow the same string would end up in every closures
+        def mk_wrap(op_name: str, *, meta_noop: bool = False):
+            # need to wrap the wrapper to get self
+            def wrapped_special_op(self, *args, **kwargs):
+                return type(self)._wrap_fn(
+                    getattr(type(self)._tensor_type, op_name),
+                    meta_noop=meta_noop,
+                )(self, *args, **kwargs)
+            return wrapped_special_op
+
+        # special methods bypass __getattr__, so they need to be added manually
+        # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
+        # NOTE: doing this from a metaclass is very convenient
+        # TODO: make this even more comprehensive
+        for binary_op in (
+            "lt", "le", "eq", "ne", "ge", "gt", "not"
+            "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
+            "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
+            "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
+            "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
+        ):
+            attr_name = f"__{binary_op}__"
+            # the result of these operators usually has the same shape and dtype as the input,
+            # so evaluation on the meta tensor can be skipped.
+            namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
+
+        for special_op in (
+            "getitem", "setitem", "len",
+        ):
+            attr_name = f"__{special_op}__"
+            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
+
+        return super().__new__(cls, name, bases, namespace, **kwargs)
+
+
+# Tree of lazy tensors
+class LazyBase(ABC, metaclass=LazyMeta):
+    _tensor_type: type
+    _meta: Any
+    _data: Any | None
+    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
+    _args: tuple
+    _func: Callable[[tuple], Any] | None
+
+    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+        super().__init__()
+        self._meta = meta
+        self._data = data
+        self._lazy = lazy if lazy is not None else deque()
+        self._args = args
+        self._func = func
+        assert self._func is not None or self._data is not None
+        if self._data is None:
+            self._lazy.append(self)
+
+    def __init_subclass__(cls) -> None:
+        if "_tensor_type" not in cls.__dict__:
+            raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
+        return super().__init_subclass__()
+
+    @staticmethod
+    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
+        # TODO: dict and set
+        if isinstance(o, (list, tuple)):
+            L = []
+            for item in o:
+                L.append(LazyBase._recurse_apply(item, fn))
+            if isinstance(o, tuple):
+                L = tuple(L)
+            return L
+        elif isinstance(o, LazyBase):
+            return fn(o)
+        else:
+            return o
+
+    @classmethod
+    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
+        def wrapped_fn(*args, **kwargs):
+            if kwargs is None:
+                kwargs = {}
+            args = ((use_self,) if use_self is not None else ()) + args
+
+            meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+
+            if isinstance(meta_noop, bool) and not meta_noop:
+                try:
+                    res = fn(*meta_args, **kwargs)
+                except NotImplementedError:
+                    # running some operations on PyTorch's Meta tensors can cause this exception
+                    res = None
+            else:
+                # some operators don't need to actually run on the meta tensors
+                assert len(args) > 0
+                res = args[0]
+                assert isinstance(res, cls)
+                res = res._meta
+                # allow operations to override the dtype
+                if meta_noop is not True:
+                    res = cls.meta_with_dtype(res, meta_noop)
+
+            if isinstance(res, cls._tensor_type):
+                def collect_replace(t: LazyBase):
+                    if collect_replace.shared_lazy is None:
+                        collect_replace.shared_lazy = t._lazy
+                    else:
+                        collect_replace.shared_lazy.extend(t._lazy)
+                        t._lazy = collect_replace.shared_lazy
+
+                # emulating a static variable
+                collect_replace.shared_lazy = None
+
+                LazyBase._recurse_apply(args, collect_replace)
+
+                shared_lazy = collect_replace.shared_lazy
+
+                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+            else:
+                del res  # not needed
+                # non-tensor return likely relies on the contents of the args
+                # (e.g. the result of torch.equal)
+                eager_args = cls.to_eager(args)
+                return fn(*eager_args, **kwargs)
+        return wrapped_fn
+
+    @classmethod
+    def to_eager(cls, t: Any) -> Any:
+        def simple_to_eager(_t: LazyBase) -> Any:
+            def already_eager_to_eager(_t: LazyBase) -> Any:
+                assert _t._data is not None
+                return _t._data
+
+            while _t._data is None:
+                lt = _t._lazy.popleft()
+                if lt._data is not None:
+                    raise ValueError(f"{lt} did not belong in the lazy queue")
+                assert lt._func is not None
+                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
+                lt._data = lt._func(lt._args)
+                # sanity check
+                assert lt._data.dtype == lt._meta.dtype
+                assert lt._data.shape == lt._meta.shape
+
+            return _t._data
+
+        # recurse into lists and/or tuples, keeping their structure
+        return cls._recurse_apply(t, simple_to_eager)
+
+    @classmethod
+    def eager_to_meta(cls, t: Any) -> Any:
+        return cls.meta_with_dtype(t, t.dtype)
+
+    # must be overridden, meta tensor init is backend-specific
+    @classmethod
+    @abstractmethod
+    def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
+
+    @classmethod
+    def from_eager(cls, t: Any) -> Any:
+        if type(t) is cls:
+            # already eager
+            return t
+        elif isinstance(t, cls._tensor_type):
+            return cls(meta=cls.eager_to_meta(t), data=t)
+        else:
+            return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
+
+
+class LazyNumpyTensor(LazyBase):
+    _tensor_type = np.ndarray
+
+    @classmethod
+    def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
+        # The initial idea was to use np.nan as the fill value,
+        # but non-float types like np.int16 can't use that.
+        # So zero it is.
+        cheat = np.zeros(1, dtype)
+        return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
+
+    def astype(self, dtype, *args, **kwargs):
+        meta = type(self).meta_with_dtype(self._meta, dtype)
+        full_args = (self, dtype,) + args
+        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
+        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+
+    def tofile(self, *args, **kwargs):
+        eager = LazyNumpyTensor.to_eager(self)
+        return eager.tofile(*args, **kwargs)
+
+    # TODO: __array_function__