]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Option to split during conversion (#6942)
authorChristian Zhou-Zheng <redacted>
Mon, 24 Jun 2024 09:42:03 +0000 (05:42 -0400)
committerGitHub <redacted>
Mon, 24 Jun 2024 09:42:03 +0000 (19:42 +1000)
* support splits in convert.py

* Support split by size and dry run to write estimated shards/filesizes

* Move split functionality to new GGUFManager class

* fix improper function signature

* tentative push of convert-hf-to-gguf support

* resolve merge + SplitArguments for easier parsing

* Fix eager tensor memory leak and remove convert.py changes

Removed a memory leak caused by unexpected reference retention to eager tensors.

Also removed GGUFManager functionality in convert.py in favor of specializing for convert-hf-to-gguf.py.

* refactor SplitStrategy to be a deque

Instead of having SplitStrategy have a `data` field that is a deque, just have SplitStrategy be a subclass of deque itself.

* fix Q8 quantization

* remove unnecessary imports in gguf_manager

* fix final? merge issue

* fix gguf_writer placement and remove comments

* oops, actually fix gguf_writer placement

* reduce duplicated code from gguf_writer

* further simplify GGUFManager

* simplify even further and standardize with GGUFWriter

* reduce diffs with master

* form shards while adding tensors, SHA256 sums agree with master

* re-add type hint

Co-authored-by: compilade <redacted>
* GGUFWriter compatibility fix

Co-authored-by: compilade <redacted>
* Shard dataclass and un-negative dont_add_architecture

* type consistency in format_n_bytes_to_str

* move kv keys to constants.py

* make pathlib explicit

* base-1024 bytes to base-1000

* rename GGUFManager to GGUFWriterSplit

* Update gguf-py/gguf/constants.py

Co-authored-by: compilade <redacted>
* fix convert-hf-to-gguf.py permissions

* fix line endings

* Update gguf-py/gguf/gguf_writer_split.py

Co-authored-by: compilade <redacted>
* convert-hf : restore executable file permission

* examples/convert-legacy-llama.py: restore executable file permission

* reinstate original gguf package import and fix type annotation

* attempt to appease the linter

* attempt 2 to appease the linter

* attempt 3 to appease the linter

* comma consistency

* Update convert-hf-to-gguf.py

Co-authored-by: compilade <redacted>
* edit cmd line args

* use simplification from #7827

* kv/ti data are still wrong

* try to refactor kv data (still fails)

* fix ti data messiness

* tidy up

* fix linting

* actually make the linter happy

* cleanup round 1

* remove SplitStrategy, SplitArguments

* appease linter

* fix typing and clean up

* fix linting

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* progress bar, fix split logic

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* catch oversights

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* swap bar orders

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* compatibility fix

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: compilade <redacted>
* Update convert-hf-to-gguf.py

Co-authored-by: compilade <redacted>
---------

Co-authored-by: Brian <redacted>
Co-authored-by: compilade <redacted>
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py

index ebfc1a1f94b7847c03844349291009a22d53c6ec..c26fad9307f15aa3e2b66ebfc35638272ce1727f 100755 (executable)
@@ -65,7 +65,8 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None):
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
+                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -96,7 +97,8 @@ class Model:
         ftype_lw: str = ftype_up.lower()
         # allow templating the file name with the output ftype, useful with the "auto" ftype
         self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
-        self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
+        self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
+                                           split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
 
     @classmethod
     def __init_subclass__(cls):
@@ -332,6 +334,8 @@ class Model:
         self.gguf_writer.close()
 
     def write_vocab(self):
+        if len(self.gguf_writer.tensors) != 1:
+            raise ValueError('Splitting the vocabulary is not supported')
         self.gguf_writer.write_header_to_file(self.fname_out)
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
@@ -2974,10 +2978,44 @@ def parse_args() -> argparse.Namespace:
         "--verbose", action="store_true",
         help="increase output verbosity",
     )
+    parser.add_argument(
+        "--split-max-tensors", type=int, default=0,
+        help="max tensors in each split",
+    )
+    parser.add_argument(
+        "--split-max-size", type=str, default="0",
+        help="max size per split N(M|G)",
+    )
+    parser.add_argument(
+        "--dry-run", action="store_true",
+        help="only print out a split plan and exit, without writing any new files",
+    )
+    parser.add_argument(
+        "--no-tensor-first-split", action="store_true",
+        help="do not add tensors to the first split (disabled by default)"
+    )
 
     return parser.parse_args()
 
 
+def split_str_to_n_bytes(split_str: str) -> int:
+    if split_str.endswith("K"):
+        n = int(split_str[:-1]) * 1000
+    elif split_str.endswith("M"):
+        n = int(split_str[:-1]) * 1000 * 1000
+    elif split_str.endswith("G"):
+        n = int(split_str[:-1]) * 1000 * 1000 * 1000
+    elif split_str.isnumeric():
+        n = int(split_str)
+    else:
+        raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")
+
+    if n < 0:
+        raise ValueError(f"Invalid split size: {split_str}, must be positive")
+
+    return n
+
+
 def main() -> None:
     args = parse_args()
 
@@ -3010,6 +3048,10 @@ def main() -> None:
         "auto": gguf.LlamaFileType.GUESSED,
     }
 
+    if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
+        logger.error("Error: Cannot use temp file when splitting")
+        sys.exit(1)
+
     if args.outfile is not None:
         fname_out = args.outfile
     else:
@@ -3027,7 +3069,10 @@ def main() -> None:
             logger.error(f"Model {hparams['architectures'][0]} is not supported")
             sys.exit(1)
 
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy, args.model_name)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
+                                     args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
+                                     split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
+                                     small_first_shard=args.no_tensor_first_split)
 
         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
@@ -3038,13 +3083,13 @@ def main() -> None:
         model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
 
         if args.vocab_only:
-            logger.info(f"Exporting model vocab to '{model_instance.fname_out}'")
+            logger.info("Exporting model vocab...")
             model_instance.write_vocab()
+            logger.info("Model vocab successfully exported.")
         else:
-            logger.info(f"Exporting model to '{model_instance.fname_out}'")
+            logger.info("Exporting model...")
             model_instance.write()
-
-        logger.info(f"Model successfully exported to '{model_instance.fname_out}'")
+            logger.info("Model successfully exported.")
 
 
 if __name__ == '__main__':
index 100594b3035262ac91e77b0b0cd2810bf5a4a224..222a2d137b08f672b2a8994363d4980803f25a27 100644 (file)
@@ -75,6 +75,11 @@ class Keys:
         SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned"
         SCALING_YARN_LOG_MUL    = "{arch}.rope.scaling.yarn_log_multiplier"
 
+    class Split:
+        LLM_KV_SPLIT_NO            = "split.no"
+        LLM_KV_SPLIT_COUNT         = "split.count"
+        LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
+
     class SSM:
         CONV_KERNEL    = "{arch}.ssm.conv_kernel"
         INNER_SIZE     = "{arch}.ssm.inner_size"
index 3b841a62583a37aa155be10ff0216002b6a05594..9869f6fe3445ab15cba2849684666dc3688210b2 100644 (file)
@@ -7,6 +7,7 @@ import struct
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, auto
+from pathlib import Path
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
@@ -31,6 +32,9 @@ from .quants import quant_shape_from_byte_shape
 logger = logging.getLogger(__name__)
 
 
+SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
+
+
 @dataclass
 class TensorInfo:
     shape: Sequence[int]
@@ -55,11 +59,11 @@ class WriterState(Enum):
 
 
 class GGUFWriter:
-    fout: BufferedWriter | None
-    path: os.PathLike[str] | str | None
+    fout: list[BufferedWriter] | None
+    path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: dict[str, TensorInfo]
-    kv_data: dict[str, GGUFValue]
+    tensors: list[dict[str, TensorInfo]]
+    kv_data: list[dict[str, GGUFValue]]
     state: WriterState
     _simple_value_packing = {
         GGUFValueType.UINT8:   "B",
@@ -76,26 +80,38 @@ class GGUFWriter:
     }
 
     def __init__(
-        self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
-        endianess: GGUFEndian = GGUFEndian.LITTLE,
+        self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
+        split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
     ):
         self.fout = None
-        self.path = path
+        self.path = Path(path) if path else None
         self.arch = arch
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
         self.use_temp_file = use_temp_file
         self.temp_file = None
-        self.tensors = dict()
-        self.kv_data = dict()
+        self.tensors = [{}]
+        self.kv_data = [{}]
+        self.split_max_tensors = split_max_tensors
+        self.split_max_size = split_max_size
+        self.dry_run = dry_run
+        self.small_first_shard = small_first_shard
         logger.info("gguf: This GGUF file is for {0} Endian only".format(
             "Big" if self.endianess == GGUFEndian.BIG else "Little",
         ))
         self.state = WriterState.NO_FILE
 
+        if self.small_first_shard:
+            self.tensors.append({})
+
         self.add_architecture()
 
-    def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
+    def format_shard_names(self, path: Path) -> list[Path]:
+        if len(self.tensors) == 1:
+            return [path]
+        return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
+
+    def open_output_file(self, path: Path | None = None) -> None:
         if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
             # allow calling this multiple times as long as the path is the same
             return
@@ -106,22 +122,58 @@ class GGUFWriter:
             self.path = path
 
         if self.path is not None:
-            if self.fout is not None:
-                self.fout.close()
-            self.fout = open(self.path, "wb")
+            filenames = self.print_plan()
+            self.fout = [open(filename, "wb") for filename in filenames]
             self.state = WriterState.EMPTY
 
-    def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
+    def print_plan(self) -> list[Path]:
+        logger.info("Writing the following files:")
+        assert self.path is not None
+        filenames = self.format_shard_names(self.path)
+        assert len(filenames) == len(self.tensors)
+        for name, tensors in zip(filenames, self.tensors):
+            logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
+
+        if self.dry_run:
+            logger.info("Dry run, not writing files")
+            exit()
+
+        return filenames
+
+    def add_shard_kv_data(self) -> None:
+        if len(self.tensors) == 1:
+            return
+
+        total_tensors = sum(len(t) for t in self.tensors)
+        assert self.fout is not None
+        total_splits = len(self.fout)
+        self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
+        for i, kv_data in enumerate(self.kv_data):
+            kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
+            kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
+            kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
+
+    def write_header_to_file(self, path: Path | None = None) -> None:
+        if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0):
+            logger.warning("Model fails split requirements, not splitting")
+
         self.open_output_file(path)
 
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
-        self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
-        self._write_packed("I", GGUF_VERSION)
-        self._write_packed("Q", len(self.tensors))
-        self._write_packed("Q", len(self.kv_data))
-        self.flush()
+        assert self.fout is not None
+        assert len(self.fout) == len(self.tensors)
+        assert len(self.kv_data) == 1
+
+        self.add_shard_kv_data()
+
+        for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
+            fout.write(self._pack("<I", GGUF_MAGIC, skip_pack_prefix = True))
+            fout.write(self._pack("I", GGUF_VERSION))
+            fout.write(self._pack("Q", len(tensors)))
+            fout.write(self._pack("Q", len(kv_data)))
+            fout.flush()
         self.state = WriterState.HEADER
 
     def write_kv_data_to_file(self) -> None:
@@ -129,13 +181,15 @@ class GGUFWriter:
             raise ValueError(f'Expected output file to contain the header, got {self.state}')
         assert self.fout is not None
 
-        kv_data = bytearray()
+        for fout, kv_data in zip(self.fout, self.kv_data):
+            kv_bytes = bytearray()
 
-        for key, val in self.kv_data.items():
-            kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
-            kv_data += self._pack_val(val.value, val.type, add_vtype=True)
+            for key, val in kv_data.items():
+                kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
+                kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
+
+            fout.write(kv_bytes)
 
-        self.fout.write(kv_data)
         self.flush()
         self.state = WriterState.KV_DATA
 
@@ -144,28 +198,29 @@ class GGUFWriter:
             raise ValueError(f'Expected output file to contain KV data, got {self.state}')
         assert self.fout is not None
 
-        ti_data = bytearray()
-        offset_tensor = 0
-
-        for name, ti in self.tensors.items():
-            ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
-            n_dims = len(ti.shape)
-            ti_data += self._pack("I", n_dims)
-            for i in range(n_dims):
-                ti_data += self._pack("Q", ti.shape[n_dims - 1 - i])
-            ti_data += self._pack("I", ti.dtype)
-            ti_data += self._pack("Q", offset_tensor)
-            offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
-
-        self.fout.write(ti_data)
-        self.flush()
+        for fout, tensors in zip(self.fout, self.tensors):
+            ti_data = bytearray()
+            offset_tensor = 0
+
+            for name, ti in tensors.items():
+                ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
+                n_dims = len(ti.shape)
+                ti_data += self._pack("I", n_dims)
+                for j in range(n_dims):
+                    ti_data += self._pack("Q", ti.shape[n_dims - 1 - j])
+                ti_data += self._pack("I", ti.dtype)
+                ti_data += self._pack("Q", offset_tensor)
+                offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)
+
+            fout.write(ti_data)
+            fout.flush()
         self.state = WriterState.TI_DATA
 
     def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
-        if key in self.kv_data:
+        if any(key in kv_data for kv_data in self.kv_data):
             raise ValueError(f'Duplicated key name {key!r}')
 
-        self.kv_data[key] = GGUFValue(value=val, type=vtype)
+        self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
 
     def add_uint8(self, key: str, val: int) -> None:
         self.add_key_value(key,val, GGUFValueType.UINT8)
@@ -206,9 +261,6 @@ class GGUFWriter:
         self.add_key_value(key, val, GGUFValueType.STRING)
 
     def add_array(self, key: str, val: Sequence[Any]) -> None:
-        if not isinstance(val, Sequence):
-            raise ValueError("Value must be a sequence for array type")
-
         self.add_key_value(key, val, GGUFValueType.ARRAY)
 
     @staticmethod
@@ -222,7 +274,7 @@ class GGUFWriter:
         if self.state is not WriterState.NO_FILE:
             raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
 
-        if name in self.tensors:
+        if any(name in tensors for tensors in self.tensors):
             raise ValueError(f'Duplicated tensor name {name!r}')
 
         if raw_dtype is None:
@@ -247,7 +299,18 @@ class GGUFWriter:
             if tensor_dtype == np.uint8:
                 tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
 
-        self.tensors[name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
+        # make sure there is at least one tensor before splitting
+        if len(self.tensors[-1]) > 0:
+            if (  # split when over tensor limit
+                self.split_max_tensors != 0
+                and len(self.tensors[-1]) >= self.split_max_tensors
+            ) or (   # split when over size limit
+                self.split_max_size != 0
+                and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size
+            ):
+                self.tensors.append({})
+
+        self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes)
 
     def add_tensor(
         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
@@ -264,7 +327,7 @@ class GGUFWriter:
         self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype)
 
         if self.temp_file is None:
-            self.tensors[name].tensor = tensor
+            self.tensors[-1][name].tensor = tensor
             return
 
         tensor.tofile(self.temp_file)
@@ -282,9 +345,24 @@ class GGUFWriter:
 
         if self.endianess == GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
-        self.write_padding(self.fout, self.fout.tell())
-        tensor.tofile(self.fout)
-        self.write_padding(self.fout, tensor.nbytes)
+
+        file_id = -1
+        for i, tensors in enumerate(self.tensors):
+            if len(tensors) > 0:
+                file_id = i
+                break
+
+        fout = self.fout[file_id]
+
+        # pop the first tensor info
+        # TODO: cleaner way to get the first key
+        first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0]
+        ti = self.tensors[file_id].pop(first_tensor_name)
+        assert ti.nbytes == tensor.nbytes
+
+        self.write_padding(fout, fout.tell())
+        tensor.tofile(fout)
+        self.write_padding(fout, tensor.nbytes)
 
         self.state = WriterState.WEIGHTS
 
@@ -293,31 +371,43 @@ class GGUFWriter:
 
         assert self.fout is not None
 
-        self.write_padding(self.fout, self.fout.tell())
+        for fout in self.fout:
+            self.write_padding(fout, fout.tell())
 
         if self.temp_file is None:
+            shard_bar = None
             bar = None
 
             if progress:
                 from tqdm import tqdm
 
-                total_bytes = sum(t.nbytes for t in self.tensors.values())
+                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())
 
+                if len(self.fout) > 1:
+                    shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True)
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            # relying on the fact that Python dicts preserve insertion order (since 3.7)
-            for ti in self.tensors.values():
-                assert ti.tensor is not None  # can only iterate once over the tensors
-                assert ti.tensor.nbytes == ti.nbytes
-                ti.tensor.tofile(self.fout)
-                if bar is not None:
-                    bar.update(ti.nbytes)
-                self.write_padding(self.fout, ti.nbytes)
-                ti.tensor = None
+            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
+                if shard_bar is not None:
+                    shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})")
+                    total = sum(ti.nbytes for ti in tensors.values())
+                    shard_bar.reset(total=(total if total > 0 else None))
+
+                # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                for ti in tensors.values():
+                    assert ti.tensor is not None  # can only iterate once over the tensors
+                    assert ti.tensor.nbytes == ti.nbytes
+                    ti.tensor.tofile(fout)
+                    if shard_bar is not None:
+                        shard_bar.update(ti.nbytes)
+                    if bar is not None:
+                        bar.update(ti.nbytes)
+                    self.write_padding(fout, ti.nbytes)
+                    ti.tensor = None
         else:
             self.temp_file.seek(0)
 
-            shutil.copyfileobj(self.temp_file, self.fout)
+            shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
             self.flush()
             self.temp_file.close()
 
@@ -325,11 +415,13 @@ class GGUFWriter:
 
     def flush(self) -> None:
         assert self.fout is not None
-        self.fout.flush()
+        for fout in self.fout:
+            fout.flush()
 
     def close(self) -> None:
         if self.fout is not None:
-            self.fout.close()
+            for fout in self.fout:
+                fout.close()
             self.fout = None
 
     def add_architecture(self) -> None:
@@ -626,6 +718,13 @@ class GGUFWriter:
 
         return kv_data
 
-    def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
-        assert self.fout is not None
-        self.fout.write(self._pack(fmt, value, skip_pack_prefix))
+    @staticmethod
+    def format_n_bytes_to_str(num: int) -> str:
+        if num == 0:
+            return "negligible - metadata only"
+        fnum = float(num)
+        for unit in ("", "K", "M", "G"):
+            if abs(fnum) < 1000.0:
+                return f"{fnum:3.1f}{unit}"
+            fnum /= 1000.0
+        return f"{fnum:.1f}T - over 1TB, split recommended"