                tensor_names_from_parts.update(model_part.keys())
                for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    if self.lazy:
-                        data = LazyTorchTensor.from_eager(data)
+                    if self.is_safetensors:
+                        if self.lazy:
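+                            # defer reading: keep a safetensors slice and only materialize the data when the tensor is evaluated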
+                            data = model_part.get_slice(name)
+                            data = LazyTorchTensor.from_safetensors_slice(data)
+                        else:
+                            data = model_part.get_tensor(name)
+                    else:
+                        data = model_part[name]
+                        if self.lazy:
+                            data = LazyTorchTensor.from_eager(data)
                    yield name, data
        # only verify tensor name presence; it doesn't matter if they are not in the right files
        torch.float32: np.float32,
    }
+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
    def numpy(self) -> gguf.LazyNumpyTensor:
        dtype = self._dtype_map[self.dtype]
        return gguf.LazyNumpyTensor(
            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
-            lazy=self._lazy,
            args=(self,),
-            func=(lambda s: s[0].numpy())
+            func=(lambda s: s.numpy())
        )
    @classmethod
-    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
        return torch.empty(size=shape, dtype=dtype, device="meta")
+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape: tuple[int, ...] = tuple(st_slice.get_shape())
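+        # the meta tensor records only dtype and shape; slicing with [:] reads the actual data when the lazy tensor is evaluated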
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+        return cast(torch.Tensor, lazy)
+
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        del types  # unused
        if func is torch.Tensor.numpy:
            return args[0].numpy()
-        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
+        return cls._wrap_fn(func)(*args, **kwargs)
def parse_args() -> argparse.Namespace:
import logging
from typing import Any, Callable
-from collections import deque
import numpy as np
from numpy.typing import DTypeLike
    _tensor_type: type
    _meta: Any
    _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
    _args: tuple
-    _func: Callable[[tuple], Any] | None
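+    # _data is computed on demand as _func(*_args, **_kwargs) (see to_eager)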
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None
-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
        super().__init__()
        self._meta = meta
        self._data = data
-        self._lazy = lazy if lazy is not None else deque()
        self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
        self._func = func
        assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)
    def __init_subclass__(cls) -> None:
        if "_tensor_type" not in cls.__dict__:
            args = ((use_self,) if use_self is not None else ()) + args
            meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too
            if isinstance(meta_noop, bool) and not meta_noop:
                try:
                        res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
            if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                            t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
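+                # store the wrapped function with its arguments; evaluation is deferred until to_eager()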
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
            else:
                del res  # not needed
                # non-tensor return likely relies on the contents of the args
    @classmethod
    def to_eager(cls, t: Any) -> Any:
        def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
+            if _t._data is not None:
                return _t._data
-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
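+            # make the arguments eager first (depth-first), then apply the stored function to them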
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape
            return _t._data
    @classmethod
    def from_eager(cls, t: Any) -> Any:
        if type(t) is cls:
-            # already eager
+            # already lazy
            return t
        elif isinstance(t, cls._tensor_type):
            return cls(meta=cls.eager_to_meta(t), data=t)
    def astype(self, dtype, *args, **kwargs):
        meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
        full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
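+        # the first stored argument is the source array; the remaining args and kwargs are forwarded to its astype()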
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
    def tofile(self, *args, **kwargs):
        eager = LazyNumpyTensor.to_eager(self)