]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : parse safetensors directly (#15667)
authorcompilade <redacted>
Sun, 9 Nov 2025 14:49:40 +0000 (09:49 -0500)
committerGitHub <redacted>
Sun, 9 Nov 2025 14:49:40 +0000 (09:49 -0500)
* convert : parse safetensors directly

* gguf-py : order safetensors tensors by name

Applies to both local and remote safetensors custom parsing.
This matches the behavior of the official safetensors implementation.

* convert : rename from_safetensors_meta to from_local_tensor

For consistency with from_remote_tensor

* convert : fix no-lazy dtypes from direct safetensors

convert_hf_to_gguf.py
gguf-py/gguf/utility.py

index b155d112b1ace5d2743129e572fc5783d48c9c5b..13448fd68116cc4ec37601b16e709f755e1afa25 100755 (executable)
@@ -218,8 +218,7 @@ class ModelBase:
             logger.info(f"gguf: indexing model part '{part_name}'")
             ctx: ContextManager[Any]
             if is_safetensors:
-                from safetensors import safe_open
-                ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
+                ctx = cast(ContextManager[Any], gguf.utility.SafetensorsLocal(self.dir_model / part_name))
             else:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
 
@@ -228,18 +227,18 @@ class ModelBase:
 
                 for name in model_part.keys():
                     if is_safetensors:
+                        data: gguf.utility.LocalTensor = model_part[name]
                         if self.lazy:
-                            data = model_part.get_slice(name)
-                            data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data)  # noqa: E731
+                            data_gen = lambda data=data: LazyTorchTensor.from_local_tensor(data)  # noqa: E731
                         else:
-                            data = model_part.get_tensor(name)
-                            data_gen = lambda data=data: data  # noqa: E731
+                            dtype = LazyTorchTensor._dtype_str_map[data.dtype]
+                            data_gen = lambda data=data, dtype=dtype: torch.from_numpy(data.mmap_bytes()).view(dtype).reshape(data.shape)  # noqa: E731
                     else:
-                        data = model_part[name]
+                        data_torch: Tensor = model_part[name]
                         if self.lazy:
-                            data_gen = lambda data=data: LazyTorchTensor.from_eager(data)  # noqa: E731
+                            data_gen = lambda data=data_torch: LazyTorchTensor.from_eager(data)  # noqa: E731
                         else:
-                            data_gen = lambda data=data: data  # noqa: E731
+                            data_gen = lambda data=data_torch: data  # noqa: E731
                     tensors[name] = data_gen
 
         # verify tensor name presence and identify potentially missing files
@@ -10079,6 +10078,16 @@ class LazyTorchTensor(gguf.LazyBase):
         lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
         return cast(torch.Tensor, lazy)
 
+    @classmethod
+    def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
+        def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
+            dtype = cls._dtype_str_map[tensor.dtype]
+            return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
+        dtype = cls._dtype_str_map[t.dtype]
+        shape = t.shape
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
+        return cast(torch.Tensor, lazy)
+
     @classmethod
     def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
         dtype = cls._dtype_str_map[remote_tensor.dtype]
index 769ccb02f0d91c51e6de812becc57b7eb7317e58..c9401a1c0a2d3104e05834efc77c5b67dfdc3a52 100644 (file)
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Literal
 
 import os
 import json
+import numpy as np
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -177,6 +179,10 @@ class SafetensorRemote:
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
 
+        # order by name (same as default safetensors behavior)
+        # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+        res = dict(sorted(res.items(), key=lambda t: t[0]))
+
         return res
 
     @classmethod
@@ -266,3 +272,77 @@ class SafetensorRemote:
         if os.environ.get("HF_TOKEN"):
             headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
         return headers
+
+
+@dataclass
+class LocalTensorRange:
+    filename: Path
+    offset: int
+    size: int
+
+
+@dataclass
+class LocalTensor:
+    dtype: str
+    shape: tuple[int, ...]
+    data_range: LocalTensorRange
+
+    def mmap_bytes(self) -> np.ndarray:
+        return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
+
+
+class SafetensorsLocal:
+    """
+        Read a safetensors file from the local filesystem.
+
+        Custom parsing gives a bit more control over the memory usage.
+        The official safetensors library doesn't expose file ranges.
+    """
+    ALIGNMENT = 8  # bytes
+
+    tensors: dict[str, LocalTensor]
+
+    def __init__(self, filename: Path):
+        with open(filename, "rb") as f:
+            metadata_length = int.from_bytes(f.read(8), byteorder='little')
+            file_size = os.stat(filename).st_size
+            if file_size < 8 + metadata_length:
+                raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {file_size}")
+
+            metadata_str = f.read(metadata_length).decode('utf-8')
+            try:
+                metadata = json.loads(metadata_str)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}")
+
+            data_start_offset = f.tell()
+            alignment = self.ALIGNMENT
+            if data_start_offset % alignment != 0:
+                data_start_offset += alignment - (data_start_offset % alignment)
+
+            tensors: dict[str, LocalTensor] = {}
+            for name, meta in metadata.items():
+                if name == "__metadata__":
+                    # ignore metadata, it's not a tensor
+                    continue
+
+                tensors[name] = LocalTensor(
+                    dtype=meta["dtype"],
+                    shape=tuple(meta["shape"]),
+                    data_range=LocalTensorRange(
+                        filename,
+                        data_start_offset + meta["data_offsets"][0],
+                        meta["data_offsets"][1] - meta["data_offsets"][0],
+                    ),
+                )
+
+            # order by name (same as default safetensors behavior)
+            # ref: https://github.com/huggingface/safetensors/blob/0816a1ae1d6b731cefd67f061d80d1cadd0dd7bb/bindings/python/src/lib.rs#L606
+            self.tensors = dict(sorted(tensors.items(), key=lambda t: t[0]))
+
+    def __enter__(self, *args, **kwargs):
+        del args, kwargs  # unused
+        return self.tensors
+
+    def __exit__(self, *args, **kwargs):
+        del args, kwargs  # unused