from enum import IntEnum
from pathlib import Path
from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
import numpy as np
import torch
dir_model: Path
ftype: int
- fname_out: Path
is_big_endian: bool
endianess: gguf.GGUFEndian
use_temp_file: bool
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
- gguf_writer: gguf.GGUFWriter
block_count: int
tensor_map: gguf.TensorNameMap
tensor_names: set[str] | None
+ fname_out: Path
+ gguf_writer: gguf.GGUFWriter
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
- def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
- if self.__class__ == Model:
- raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
+ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
+ if type(self) is Model:
+ raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
self.ftype = ftype
- self.fname_out = fname_out
self.is_big_endian = is_big_endian
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.use_temp_file = use_temp_file
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
self.hparams = Model.load_hparams(self.dir_model)
- self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
+ if self.ftype == gguf.LlamaFileType.GUESSED:
+ # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
+ _, first_tensor = next(self.get_tensors())
+ if first_tensor.dtype == torch.float16:
+ logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
+ self.ftype = gguf.LlamaFileType.MOSTLY_F16
+ else:
+ logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
+ self.ftype = gguf.LlamaFileType.MOSTLY_BF16
+ ftype_up: str = self.ftype.name.partition("_")[2].upper()
+ ftype_lw: str = ftype_up.lower()
+ # allow templating the file name with the output ftype, useful with the "auto" ftype
+ self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
+ self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
@classmethod
def __init_subclass__(cls):
raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
- name: str = gguf.TENSOR_NAMES[key]
if key not in gguf.MODEL_TENSORS[self.model_arch]:
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
+ name: str = gguf.TENSOR_NAMES[key]
if "{bid}" in name:
assert bid is not None
name = name.format(bid=bid)
return name + suffix
+ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool:
+ if key not in gguf.MODEL_TENSORS[self.model_arch]:
+ return False
+ key_name: str = gguf.TENSOR_NAMES[key]
+ if "{bid}" in key_name:
+ if bid is None:
+ return False
+ key_name = key_name.format(bid=bid)
+ else:
+ if bid is not None:
+ return False
+ return name == (key_name + suffix)
+
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
return False
def write_tensors(self):
+ # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+ def np_fp32_to_bf16(n: np.ndarray):
+ # force nan to quiet
+ n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
+ # flush subnormals to zero
+ n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
+ # round to nearest even
+ n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
+ return n.astype(np.int16)
+
+ # Doing this row-wise is much, much faster than element-wise, hence the signature
+ v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
+ if self.lazy:
+ # TODO: find a way to implicitly wrap np.vectorize functions
+ # NOTE: the type is changed to reflect otypes passed to np.vectorize above
+ v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
+
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
for name, data_torch in self.get_tensors():
data: np.ndarray = data # type hint
n_dims = len(data.shape)
data_dtype = data.dtype
-
- # if f32 desired, convert any float16 to float32
- if self.ftype == 0 and data_dtype == np.float16:
- data = data.astype(np.float32)
+ data_qtype: gguf.GGMLQuantizationType | None = None
# when both are True, f32 should win
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
- extra_f32 = extra_f32 or n_dims == 1 or new_name.endswith("_norm.weight")
+ # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
+ extra_f32 = any(cond for cond in (
+ extra_f32,
+ n_dims == 1,
+ new_name.endswith("_norm.weight"),
+ ))
+
+ # Some tensor types are always in float32
+ extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
+ gguf.MODEL_TENSOR.FFN_GATE_INP,
+ gguf.MODEL_TENSOR.POS_EMBD,
+ gguf.MODEL_TENSOR.TOKEN_TYPES,
+ ))
# if f16 desired, convert any float32 2-dim weight tensors to float16
- extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
-
- # when both extra_f32 and extra_f16 are False, convert to float32 by default
- if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
- data = data.astype(np.float32)
-
- if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
- data = data.astype(np.float16)
+ extra_f16 = any(cond for cond in (
+ extra_f16,
+ (name.endswith(".weight") and n_dims >= 2),
+ ))
+
+ if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
+ if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+ if data_dtype != np.float16:
+ data = data.astype(np.float16)
+ data_qtype = gguf.GGMLQuantizationType.F16
+
+ elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+ if data_dtype != np.float32:
+ data = data.astype(np.float32)
+ data = v_fp32_to_bf16(data.view(np.int32))
+ assert data.dtype == np.int16
+ data_qtype = gguf.GGMLQuantizationType.BF16
+
+ else: # by default, convert to float32
+ if data_dtype != np.float32:
+ data = data.astype(np.float32)
+ data_qtype = gguf.GGMLQuantizationType.F32
+
+ assert data_qtype is not None
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
# n_dims is implicit in the shape
- logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
+ logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
- self.gguf_writer.add_tensor(new_name, data)
+ self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
def write(self):
self.write_tensors()
return [(self.map_tensor_name(name), data_torch)]
- def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
- del new_name, bid, n_dims # unused
-
- # not used with get_rows, must be F32
- return name == "embeddings.token_type_embeddings.weight"
-
@Model.register("NomicBertModel")
class NomicBertModel(BertModel):
# tree of lazy tensors
-class LazyTorchTensor:
- _meta: Tensor
- _data: Tensor | None
- _args: tuple
- _func: Callable[[tuple], Tensor] | None
-
- def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
- self._meta = meta
- self._data = data
- self._args = args
- self._func = func
-
- @staticmethod
- def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
- # TODO: dict and set
- if isinstance(o, (list, tuple)):
- L = []
- for item in o:
- L.append(LazyTorchTensor._recurse_apply(item, fn))
- if isinstance(o, tuple):
- L = tuple(L)
- return L
- elif isinstance(o, LazyTorchTensor):
- return fn(o)
- else:
- return o
-
- def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
- def wrapped_fn(*args, **kwargs):
- if kwargs is None:
- kwargs = {}
- args = ((self,) if use_self else ()) + args
-
- meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
-
- return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
- return wrapped_fn
-
- def __getattr__(self, __name: str) -> Any:
- meta_attr = getattr(self._meta, __name)
- if callable(meta_attr):
- return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
- elif isinstance(meta_attr, torch.Tensor):
- # for things like self.T
- return self._wrap_fn(lambda s: getattr(s, __name))(self)
- else:
- return meta_attr
+class LazyTorchTensor(gguf.LazyBase):
+ _tensor_type = torch.Tensor
+ # to keep the type-checker happy
+ dtype: torch.dtype
+ shape: torch.Size
+ # only used when converting a torch.Tensor to a np.ndarray
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
}
- def numpy(self) -> gguf.LazyTensor:
+ def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
- return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
-
- @overload
- @staticmethod
- def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
-
- @overload
- @staticmethod
- def to_eager(t: tuple) -> tuple: ...
-
- @staticmethod
- def to_eager(t: Any) -> Any:
- def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
- # wake up the lazy tensor
- if _t._data is None and _t._func is not None:
- # recurse into its arguments
- _t._args = LazyTorchTensor.to_eager(_t._args)
- _t._data = _t._func(_t._args)
- if _t._data is not None:
- return _t._data
- else:
- raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
-
- # recurse into lists and/or tuples, keeping their structure
- return LazyTorchTensor._recurse_apply(t, simple_to_eager)
+ return gguf.LazyNumpyTensor(
+ meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
+ lazy=self._lazy,
+ args=(self,),
+ func=(lambda s: s[0].numpy())
+ )
- @staticmethod
- def from_eager(t: Tensor) -> Tensor:
- if (t.__class__ == LazyTorchTensor):
+ @classmethod
+ def eager_to_meta(cls, t: Tensor) -> Tensor:
+ if t.is_meta:
return t
- return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
+ return t.detach().to("meta")
+
+ @classmethod
+ def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
+ m = m.detach()
+ if not m.is_meta:
+ m = m.to("meta")
+ m.dtype = dtype
+ return m
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if func is torch.Tensor.numpy:
return args[0].numpy()
- if func is torch.equal:
- eager_args = LazyTorchTensor.to_eager(args)
- return func(*eager_args, **kwargs)
- return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
-
- # special methods bypass __getattr__, so they need to be added manually
- # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
- # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
- # as self._meta is currently used), because then the following
- # operations would by default not be wrapped, and so not propagated
- # when the tensor is made eager.
- # It's better to get non-silent errors for not-yet-supported operators.
- # TODO: add more when needed to avoid clutter, or find a more concise way
- def __neg__(self, *args): # mamba
- return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
-
- def __add__(self, *args): # gemma
- return self._wrap_fn(torch.Tensor.__add__)(self, *args)
-
- def __getitem__(self, *args): # bloom falcon refact internlm2
- return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
+ return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
def parse_args() -> argparse.Namespace:
)
parser.add_argument(
"--outfile", type=Path,
- help="path to write to; default: based on input",
+ help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
- "--outtype", type=str, choices=["f32", "f16"], default="f16",
- help="output format - use f32 for float32, f16 for float16",
+ "--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
+ help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
- ftype_map = {
- "f32": gguf.GGMLQuantizationType.F32,
- "f16": gguf.GGMLQuantizationType.F16,
+ ftype_map: dict[str, gguf.LlamaFileType] = {
+ "f32": gguf.LlamaFileType.ALL_F32,
+ "f16": gguf.LlamaFileType.MOSTLY_F16,
+ "bf16": gguf.LlamaFileType.MOSTLY_BF16,
+ "auto": gguf.LlamaFileType.GUESSED,
}
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
- fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
+ fname_out = dir_model / 'ggml-model-{ftype}.gguf'
logger.info(f"Loading model: {dir_model.name}")
logger.info("Set model tokenizer")
model_instance.set_vocab()
+ model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+
if args.vocab_only:
- logger.info(f"Exporting model vocab to '{fname_out}'")
+ logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
model_instance.write_vocab()
else:
- logger.info(f"Exporting model to '{fname_out}'")
+ logger.info(f"Exporting model to '{model_instance.fname_out}'")
model_instance.write()
- logger.info(f"Model successfully exported to '{fname_out}'")
+ logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
if __name__ == '__main__':
import tempfile
from enum import Enum, auto
from io import BufferedWriter
-from typing import IO, Any, Callable, Sequence, Mapping
+from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits
import numpy as np
logger = logging.getLogger(__name__)
-class LazyTensor:
- data: Callable[[], np.ndarray[Any, Any]]
- # to avoid too deep recursion
- functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
- dtype: np.dtype[Any]
- shape: tuple[int, ...]
-
- def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
- self.data = data
- self.functions = []
- self.dtype = np.dtype(dtype)
- self.shape = shape
-
- def astype(self, dtype: type, **kwargs) -> LazyTensor:
- self.functions.append(lambda n: n.astype(dtype, **kwargs))
- self.dtype = np.dtype(dtype)
- return self
-
- @property
- def nbytes(self) -> int:
- size = 1
- for n in self.shape:
- size *= n
- return size * self.dtype.itemsize
-
- def tofile(self, *args, **kwargs) -> None:
- data = self.data()
- for f in self.functions:
- data = f(data)
- assert data.shape == self.shape
- assert data.dtype == self.dtype
- assert data.nbytes == self.nbytes
- self.functions = []
- self.data = lambda: data
- data.tofile(*args, **kwargs)
-
- def byteswap(self, *args, **kwargs) -> LazyTensor:
- self.functions.append(lambda n: n.byteswap(*args, **kwargs))
- return self
-
-
class WriterState(Enum):
EMPTY = auto()
HEADER = auto()
class GGUFWriter:
fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
- tensors: list[np.ndarray[Any, Any] | LazyTensor]
+ tensors: list[np.ndarray[Any, Any]]
_simple_value_packing = {
GGUFValueType.UINT8: "B",
GGUFValueType.INT8: "b",
self.ti_data_count += 1
def add_tensor(
- self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
+ self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
if pad != 0:
fp.write(bytes([0] * pad))
- def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
+ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
- def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
+ def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(
Keys.General.QUANTIZATION_VERSION, quantization_version)
--- /dev/null
+from __future__ import annotations
+from abc import ABC, ABCMeta, abstractmethod
+
+import logging
+from typing import Any, Callable
+from collections import deque
+
+import numpy as np
+from numpy.typing import DTypeLike
+
+
+logger = logging.getLogger(__name__)
+
+
+class LazyMeta(ABCMeta):
+
+ def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
+ def __getattr__(self, __name: str) -> Any:
+ meta_attr = getattr(self._meta, __name)
+ if callable(meta_attr):
+ return type(self)._wrap_fn(
+ (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
+ use_self=self,
+ )
+ elif isinstance(meta_attr, self._tensor_type):
+ # e.g. self.T with torch.Tensor should still be wrapped
+ return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
+ else:
+ # no need to wrap non-tensor properties,
+ # and they likely don't depend on the actual contents of the tensor
+ return meta_attr
+
+ namespace["__getattr__"] = __getattr__
+
+ # need to make a builder for the wrapped wrapper to copy the name,
+ # or else it fails with very cryptic error messages,
+ # because somehow the same string would end up in every closures
+ def mk_wrap(op_name: str, *, meta_noop: bool = False):
+ # need to wrap the wrapper to get self
+ def wrapped_special_op(self, *args, **kwargs):
+ return type(self)._wrap_fn(
+ getattr(type(self)._tensor_type, op_name),
+ meta_noop=meta_noop,
+ )(self, *args, **kwargs)
+ return wrapped_special_op
+
+ # special methods bypass __getattr__, so they need to be added manually
+ # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
+ # NOTE: doing this from a metaclass is very convenient
+ # TODO: make this even more comprehensive
+ for binary_op in (
+ "lt", "le", "eq", "ne", "ge", "gt", "not"
+ "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
+ "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
+ "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
+ "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
+ ):
+ attr_name = f"__{binary_op}__"
+ # the result of these operators usually has the same shape and dtype as the input,
+ # so evaluation on the meta tensor can be skipped.
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
+
+ for special_op in (
+ "getitem", "setitem", "len",
+ ):
+ attr_name = f"__{special_op}__"
+ namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
+
+ return super().__new__(cls, name, bases, namespace, **kwargs)
+
+
+# Tree of lazy tensors
+class LazyBase(ABC, metaclass=LazyMeta):
+ _tensor_type: type
+ _meta: Any
+ _data: Any | None
+ _lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
+ _args: tuple
+ _func: Callable[[tuple], Any] | None
+
+ def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+ super().__init__()
+ self._meta = meta
+ self._data = data
+ self._lazy = lazy if lazy is not None else deque()
+ self._args = args
+ self._func = func
+ assert self._func is not None or self._data is not None
+ if self._data is None:
+ self._lazy.append(self)
+
+ def __init_subclass__(cls) -> None:
+ if "_tensor_type" not in cls.__dict__:
+ raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
+ return super().__init_subclass__()
+
+ @staticmethod
+ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
+ # TODO: dict and set
+ if isinstance(o, (list, tuple)):
+ L = []
+ for item in o:
+ L.append(LazyBase._recurse_apply(item, fn))
+ if isinstance(o, tuple):
+ L = tuple(L)
+ return L
+ elif isinstance(o, LazyBase):
+ return fn(o)
+ else:
+ return o
+
+ @classmethod
+ def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
+ def wrapped_fn(*args, **kwargs):
+ if kwargs is None:
+ kwargs = {}
+ args = ((use_self,) if use_self is not None else ()) + args
+
+ meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+
+ if isinstance(meta_noop, bool) and not meta_noop:
+ try:
+ res = fn(*meta_args, **kwargs)
+ except NotImplementedError:
+ # running some operations on PyTorch's Meta tensors can cause this exception
+ res = None
+ else:
+ # some operators don't need to actually run on the meta tensors
+ assert len(args) > 0
+ res = args[0]
+ assert isinstance(res, cls)
+ res = res._meta
+ # allow operations to override the dtype
+ if meta_noop is not True:
+ res = cls.meta_with_dtype(res, meta_noop)
+
+ if isinstance(res, cls._tensor_type):
+ def collect_replace(t: LazyBase):
+ if collect_replace.shared_lazy is None:
+ collect_replace.shared_lazy = t._lazy
+ else:
+ collect_replace.shared_lazy.extend(t._lazy)
+ t._lazy = collect_replace.shared_lazy
+
+ # emulating a static variable
+ collect_replace.shared_lazy = None
+
+ LazyBase._recurse_apply(args, collect_replace)
+
+ shared_lazy = collect_replace.shared_lazy
+
+ return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+ else:
+ del res # not needed
+ # non-tensor return likely relies on the contents of the args
+ # (e.g. the result of torch.equal)
+ eager_args = cls.to_eager(args)
+ return fn(*eager_args, **kwargs)
+ return wrapped_fn
+
+ @classmethod
+ def to_eager(cls, t: Any) -> Any:
+ def simple_to_eager(_t: LazyBase) -> Any:
+ def already_eager_to_eager(_t: LazyBase) -> Any:
+ assert _t._data is not None
+ return _t._data
+
+ while _t._data is None:
+ lt = _t._lazy.popleft()
+ if lt._data is not None:
+ raise ValueError(f"{lt} did not belong in the lazy queue")
+ assert lt._func is not None
+ lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
+ lt._data = lt._func(lt._args)
+ # sanity check
+ assert lt._data.dtype == lt._meta.dtype
+ assert lt._data.shape == lt._meta.shape
+
+ return _t._data
+
+ # recurse into lists and/or tuples, keeping their structure
+ return cls._recurse_apply(t, simple_to_eager)
+
+ @classmethod
+ def eager_to_meta(cls, t: Any) -> Any:
+ return cls.meta_with_dtype(t, t.dtype)
+
+ # must be overridden, meta tensor init is backend-specific
+ @classmethod
+ @abstractmethod
+ def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
+
+ @classmethod
+ def from_eager(cls, t: Any) -> Any:
+ if type(t) is cls:
+ # already eager
+ return t
+ elif isinstance(t, cls._tensor_type):
+ return cls(meta=cls.eager_to_meta(t), data=t)
+ else:
+ return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
+
+
+class LazyNumpyTensor(LazyBase):
+ _tensor_type = np.ndarray
+
+ @classmethod
+ def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
+ # The initial idea was to use np.nan as the fill value,
+ # but non-float types like np.int16 can't use that.
+ # So zero it is.
+ cheat = np.zeros(1, dtype)
+ return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
+
+ def astype(self, dtype, *args, **kwargs):
+ meta = type(self).meta_with_dtype(self._meta, dtype)
+ full_args = (self, dtype,) + args
+ # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
+ return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+
+ def tofile(self, *args, **kwargs):
+ eager = LazyNumpyTensor.to_eager(self)
+ return eager.tofile(*args, **kwargs)
+
+ # TODO: __array_function__