]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert.py: Fix loading safetensors and ggml format on Windows (#991)
authorcomex <redacted>
Sat, 15 Apr 2023 21:53:21 +0000 (14:53 -0700)
committerGitHub <redacted>
Sat, 15 Apr 2023 21:53:21 +0000 (23:53 +0200)
Calling `mmap.mmap` on Windows apparently resets the file offset of the
raw file object (and makes the BufferedReader return a *negative* file
offset).  For safetensors, avoid using the file offset after calling
mmap.  For GGML format, explicitly save and restore the offset.

Fixes #966.

convert.py

index 056dc618daa481da3a05a9e7a6928bc0aad4517f..4e28a45ebb4e8bddb7198eb947a899feba58a518 100644 (file)
@@ -735,7 +735,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
     header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size))
     # Use mmap for the actual data to avoid race conditions with the file offset.
     mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
-    byte_buf = mapped[fp.tell():]
+    byte_buf = mapped[8 + header_size:]
 
     def convert(info: Dict[str, Any]) -> LazyTensor:
         data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
@@ -761,7 +761,7 @@ def must_read(fp: IO[bytes], length: int) -> bytes:
     return ret
 
 
-def lazy_load_ggml_file(fp: IO[bytes], path: Path) -> ModelPlus:
+def lazy_load_ggml_file(fp: io.BufferedReader, path: Path) -> ModelPlus:
     magic = must_read(fp, 4)[::-1]
     if magic in (b'ggmf', b'ggjt'):
         version, = struct.unpack("i", must_read(fp, 4))
@@ -795,7 +795,9 @@ def lazy_load_ggml_file(fp: IO[bytes], path: Path) -> ModelPlus:
 
     model: LazyModel = {}
     # Use mmap for the actual data to avoid race conditions with the file offset.
+    off = fp.raw.tell()
     mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
+    fp.raw.seek(off) # needed on Windows
 
     def read_tensor() -> None:  # this is a function so that variables captured in `load` don't change
         shape_len, name_len, ftype = struct.unpack("iii", must_read(fp, 12))