convert_hf : faster lazy safetensors (#8482)
author    compilade <redacted>    Tue, 16 Jul 2024 03:13:10 +0000 (23:13 -0400)
committer GitHub <redacted>       Tue, 16 Jul 2024 03:13:10 +0000 (23:13 -0400)
* convert_hf : faster lazy safetensors

This makes '--dry-run' much, much faster.
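
Roughly, the speed-up comes from wrapping safetensors slices instead of fully
loaded tensors, so a '--dry-run' only touches the file headers for shape/dtype
metadata and defers reading tensor data until a slice is actually indexed.
A minimal sketch of the idea (the file path and tensor name are placeholders,
not from this commit):

    from safetensors import safe_open

    with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:
        sl = f.get_slice("model.embed_tokens.weight")  # no tensor data read yet
        shape = tuple(sl.get_shape())  # shape comes from the header alone
        dtype = sl.get_dtype()         # dtype as a string, e.g. "F16"
        # data = sl[:]                 # only this line would read the tensor

In the convert_hf_to_gguf.py hunk below, LazyTorchTensor.from_safetensors_slice
wraps such a slice and performs the 's[:]' read only when the tensor is made eager.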

* convert_hf : fix memory leak in lazy MoE conversion

The '_lazy' queue was sometimes self-referential,
creating reference cycles among objects old enough that the
garbage collector rarely visited them, which could exhaust memory.
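
For illustration, a minimal sketch of the kind of cycle the shared '_lazy'
deque could create (the class and setup below are simplified stand-ins,
not the converter's real code):

    from collections import deque

    class LazyTensorSketch:
        def __init__(self, shared: deque | None = None):
            # every tensor in a graph holds the same deque...
            self._lazy = shared if shared is not None else deque()
            # ...and also appends itself to it: tensor -> deque -> tensor
            self._lazy.append(self)

    a = LazyTensorSketch()
    b = LazyTensorSketch(shared=a._lazy)
    # 'a' and 'b' now keep each other reachable through the shared queue,
    # so they can only be freed by the cyclic garbage collector, which may
    # run long after large MoE tensor graphs have piled up.

The fix drops the shared queue entirely; as the lazy.py hunk below shows,
to_eager() now resolves a lazy tensor by recursing over its '_args' instead.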

convert_hf_to_gguf.py
gguf-py/gguf/lazy.py
gguf-py/gguf/tensor_mapping.py

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index a755b0a60bf0a7b962bb1a2d9c69bbe71d76940b..c2aba909706d0a8fa7dbeca5346a3a79b9e6270c 100755 (executable)
@@ -148,9 +148,16 @@ class Model:
                 tensor_names_from_parts.update(model_part.keys())
 
                 for name in model_part.keys():
-                    data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
-                    if self.lazy:
-                        data = LazyTorchTensor.from_eager(data)
+                    if self.is_safetensors:
+                        if self.lazy:
+                            data = model_part.get_slice(name)
+                            data = LazyTorchTensor.from_safetensors_slice(data)
+                        else:
+                            data = model_part.get_tensor(name)
+                    else:
+                        data = model_part[name]
+                        if self.lazy:
+                            data = LazyTorchTensor.from_eager(data)
                     yield name, data
 
         # only verify tensor name presence; it doesn't matter if they are not in the right files
@@ -3424,19 +3431,46 @@ class LazyTorchTensor(gguf.LazyBase):
         torch.float32: np.float32,
     }
 
+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
     def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
         return gguf.LazyNumpyTensor(
             meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
-            lazy=self._lazy,
             args=(self,),
-            func=(lambda s: s[0].numpy())
+            func=(lambda s: s.numpy())
         )
 
     @classmethod
-    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
         return torch.empty(size=shape, dtype=dtype, device="meta")
 
+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape: tuple[int, ...] = tuple(st_slice.get_shape())
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         del types  # unused
@@ -3447,7 +3481,7 @@ class LazyTorchTensor(gguf.LazyBase):
         if func is torch.Tensor.numpy:
             return args[0].numpy()
 
-        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
+        return cls._wrap_fn(func)(*args, **kwargs)
 
 
 def parse_args() -> argparse.Namespace:
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py
index 6e266f34f76ace877eca3be5ff096c7e4c72dcb4..ac98d9a92a3e9e34efca19491c0c032e9e3526d0 100644 (file)
@@ -3,7 +3,6 @@ from abc import ABC, ABCMeta, abstractmethod
 
 import logging
 from typing import Any, Callable
-from collections import deque
 
 import numpy as np
 from numpy.typing import DTypeLike
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _tensor_type: type
     _meta: Any
     _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
     _args: tuple
-    _func: Callable[[tuple], Any] | None
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None
 
-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
         super().__init__()
         self._meta = meta
         self._data = data
-        self._lazy = lazy if lazy is not None else deque()
         self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
         assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)
 
     def __init_subclass__(cls) -> None:
         if "_tensor_type" not in cls.__dict__:
@@ -117,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
             args = ((use_self,) if use_self is not None else ()) + args
 
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too
 
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
@@ -140,23 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                            t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
@@ -168,26 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def to_eager(cls, t: Any) -> Any:
         def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
+            if _t._data is not None:
                 return _t._data
 
-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape
 
             return _t._data
 
@@ -206,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def from_eager(cls, t: Any) -> Any:
         if type(t) is cls:
-            # already eager
+            # already lazy
             return t
         elif isinstance(t, cls._tensor_type):
             return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 7264240f5e17a2967f4dd86febf4564d5fca7376..9aa2209e2f69f8f7ad3b6fbcc3b78e4dbf864349 100644 (file)
@@ -602,14 +602,12 @@ class TensorNameMap:
             for tensor, keys in self.block_mappings_cfg.items():
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
-                # TODO: make this configurable
-                n_experts = 160
-                for xid in range(n_experts):
-                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
-                    self.mapping[tensor_name] = (tensor, tensor_name)
-                    for key in keys:
-                        key = key.format(bid = bid, xid = xid)
-                        self.mapping[key] = (tensor, tensor_name)
+
+                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+                self.mapping[tensor_name] = (tensor, tensor_name)
+                for key in keys:
+                    key = key.format(bid = bid)
+                    self.mapping[key] = (tensor, tensor_name)
 
     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)