]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert: add ability to convert safetensors files (#1276)
authorubik2 <redacted>
Mon, 8 May 2023 11:54:26 +0000 (04:54 -0700)
committerGitHub <redacted>
Mon, 8 May 2023 11:54:26 +0000 (13:54 +0200)
* when loading a safetensors file, ignore the metadata header
* check for safetensors files first, and only use PyTorch versions when safetensors aren't available

convert.py

index 126beaabc1b82da42f9a1a4897897a137f9177e0..8f4f0399e1c52dd1875441229dbc80b7ff783a92 100644 (file)
@@ -766,7 +766,7 @@ def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
             return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape))
         description = f'safetensors begin={begin} end={end} type={data_type} path={path}'
         return LazyTensor(load, shape, data_type, description)
-    model = {name: convert(info) for (name, info) in header.items()}
+    model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'}
     return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None)
 
 
@@ -1051,8 +1051,12 @@ def load_some_model(path: Path) -> ModelPlus:
     '''Load a model of any supported format.'''
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
-        globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
-        files = [file for glob in globs for file in path.glob(glob)]
+        # Check if it's a set of safetensors files first
+        files = list(path.glob("model-00001-of-*.safetensors"))
+        if not files:
+            # Try the PyTorch patterns too, with lower priority
+            globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"]
+            files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try GGML too, but with lower priority, since if both a non-GGML
             # model and a GGML model exist in the same directory, we assume the