logger.info(f"gguf: indexing model part '{part_name}'")
ctx: ContextManager[Any]
if is_safetensors:
- from safetensors import safe_open
- ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
+ ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
else:
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
for name in model_part.keys():
if is_safetensors:
+ data: gguf.utility.LocalTensor = model_part[name]
if self.lazy:
- data = model_part.get_slice(name)
- data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
+ data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data) # noqa: E731
else:
- data = model_part.get_tensor(name)
- data_gen = lambda data=data: data # noqa: E731
+ dtype = LazyTorchTensor._dtype_str_map[data.dtype]
+ data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape) # noqa: E731
else:
- data = model_part[name]
+ data_torch: Tensor = model_part[name]
if self.lazy:
- data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
+ data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data) # noqa: E731
else:
- data_gen = lambda data=data: data # noqa: E731
+ data_gen = lambda data=data_torch: data # noqa: E731
tensors[name] = data_gen
# verify tensor name presence and identify potentially missing files
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
return cast(torch.Tensor, lazy)
+ @classmethod
+ def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
+ def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
+ dtype = cls._dtype_str_map[tensor.dtype]
+ return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
+ dtype = cls._dtype_str_map[t.dtype]
+ shape = t.shape
+ lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
+ return cast(torch.Tensor, lazy)
+
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
dtype = cls._dtype_str_map[remote_tensor.dtype]
from __future__ import annotations
from dataclasses import dataclass
+from pathlib import Path
from typing import Literal
import os
import json
+import numpy as np
def fill_templated_filename(filename: str, output_type: str | None) -> str:
except KeyError as e:
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
+ # order by name (same as default safetensors behavior)
+ # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+ res = dict(sorted(res.items(), key=lambda t: t[0]))
+
return res
@classmethod
if os.environ.get("HF_TOKEN"):
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
return headers
+
+
+@dataclass
+class LocalTensorRange:
+ filename: Path
+ offset: int
+ size: int
+
+
+@dataclass
+class LocalTensor:
+ dtype: str
+ shape: tuple[int, ...]
+ data_range: LocalTensorRange
+
+ def mmap_bytes(self) -> np.ndarray:
+ return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
+
+
+class SafetensorsLocal:
+ """
+ Read a safetensors file from the local filesystem.
+
+ Custom parsing gives a bit more control over the memory usage.
+ The official safetensors library doesn't expose file ranges.
+ """
+ ALIGNMENT = 8 # bytes
+
+ tensors: dict[str, LocalTensor]
+
+ def __init__(self, filename: Path):
+ with open(filename, "rb") as f:
+ metadata_length = int.from_bytes(f.read(8), byteorder='little')
+ file_size = os.stat(filename).st_size
+ if file_size < 8 + metadata_length:
+ raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
+
+ metadata_str = f.read(metadata_length).decode('utf-8')
+ try:
+ metadata = json.loads(metadata_str)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
+
+ data_start_offset = f.tell()
+ alignment = self.ALIGNMENT
+ if data_start_offset % alignment != 0:
+ data_start_offset += alignment - (data_start_offset % alignment)
+
+ tensors: dict[str, LocalTensor] = {}
+ for name, meta in metadata.items():
+ if name == "__metadata__":
+ # ignore metadata, it's not a tensor
+ continue
+
+ tensors[name] = LocalTensor(
+ dtype=meta["dtype"],
+ shape=tuple(meta["shape"]),
+ data_range=LocalTensorRange(
+ filename,
+ data_start_offset + meta["data_offsets"][0],
+ meta["data_offsets"][1] - meta["data_offsets"][0],
+ ),
+ )
+
+ # order by name (same as default safetensors behavior)
+ # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+ self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
+
+ def __enter__(self, *args, **kwargs):
+ del args, kwargs # unused
+ return self.tensors
+
+ def __exit__(self, *args, **kwargs):
+ del args, kwargs # unused