]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert.py: also look for plain model.safetensors (#4043)
authorafrideva <redacted>
Tue, 14 Nov 2023 01:03:40 +0000 (17:03 -0800)
committerGitHub <redacted>
Tue, 14 Nov 2023 01:03:40 +0000 (18:03 -0700)
* add safetensors to convert.py help message

* Check for single-file safetensors model

* Update convert.py "model" option help message

* revert convert.py help message change

convert.py

index a4b87e08849bcc1037c33d6cd6cb7cc652fd2e1b..3d6216f1d4e7abb376280d44fe04929ad8a1f26f 100755 (executable)
@@ -1036,7 +1036,8 @@ def load_some_model(path: Path) -> ModelPlus:
     # Be extra-friendly and accept either a file or a directory:
     if path.is_dir():
         # Check if it's a set of safetensors files first
-        files = list(path.glob("model-00001-of-*.safetensors"))
+        globs = ["model-00001-of-*.safetensors", "model.safetensors"]
+        files = [file for glob in globs for file in path.glob(glob)]
         if not files:
             # Try the PyTorch patterns too, with lower priority
             globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
@@ -1123,7 +1124,7 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--outtype",     choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
-    parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
+    parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
     parser.add_argument("--vocabtype",   choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx",         type=int,               help="model training context (default: based on input)")
     parser.add_argument("--concurrency", type=int,               help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)