]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf-py: Refactor and allow reading/modifying existing GGUF files (#3981)
authorKerfuffle <redacted>
Sat, 11 Nov 2023 05:04:50 +0000 (22:04 -0700)
committerGitHub <redacted>
Sat, 11 Nov 2023 05:04:50 +0000 (08:04 +0300)
* gguf-py: Refactor and add file reading support

* Replay changes from #3871

Credit to @cebtenzzre for that pull

* Various type annotation fixes.

* sort imports with isort (again)

* Fix missing return statement in add_tensor

* style cleanup with flake8

* fix NamedTuple and Enum usage

* Fix an issue with state init in GGUFReader

Move examples to an examples/ directory

Clean up examples

Add an example of modifying keys in a GGUF file

Update documentation with info on examples

Try to support people importing gguf/gguf.py directly

* Damagage is not a word.

* Clean up gguf-py/examples/modify_gguf.py whitespace

Co-authored-by: Jared Van Bortel <redacted>
* Update gguf-py/examples/modify_gguf.py formatting

Co-authored-by: Jared Van Bortel <redacted>
* Update gguf-py/gguf/gguf_reader.py type hint

Co-authored-by: Jared Van Bortel <redacted>
* Make examples executable, formatting changes

* Add more information to GGUFReader and examples comments

* Include a gguf Python package version bump

* Add convert-gguf-endian.py script

* cleanup

* gguf-py : bump minor version

* Reorganize scripts

* Make GGUFReader endian detection less arbitrary

* Add JSON dumping support to gguf-dump.py

Which I kind of regret now

* A few for gguf-dump.py cleanups

* Murder accidental tuple in gguf-py/scripts/gguf-dump.py

Co-authored-by: Jared Van Bortel <redacted>
* cleanup

* constants : remove unneeded type annotations

* fix python 3.8 compat

* Set up gguf- scripts in pyproject.toml

* And include scripts/__init__.py, derp

* convert.py: We can't currently support Q8_0 on big endian.

* gguf-py: SpecialVocab: Always try available sources for special token ids

gguf-py: SpecialVocab: Try to load merges from merges.txt if not in tokenizer.json

gguf-py: SpecialVocab: Add 'add_bos_token' type bools to GGUF metadata
u

* cleanup

* Promote add_X_token to GGUF metadata for BOS and EOS

---------

Co-authored-by: Jared Van Bortel <redacted>
Co-authored-by: Jared Van Bortel <redacted>
20 files changed:
convert-baichuan-hf-to-gguf.py
convert-llama-ggml-to-gguf.py
convert-persimmon-to-gguf.py
convert.py
examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py
gguf-py/README.md
gguf-py/examples/writer.py [new file with mode: 0755]
gguf-py/gguf/__init__.py
gguf-py/gguf/constants.py [new file with mode: 0644]
gguf-py/gguf/gguf.py
gguf-py/gguf/gguf_reader.py [new file with mode: 0644]
gguf-py/gguf/gguf_writer.py [new file with mode: 0644]
gguf-py/gguf/tensor_mapping.py [new file with mode: 0644]
gguf-py/gguf/vocab.py [new file with mode: 0644]
gguf-py/pyproject.toml
gguf-py/scripts/__init__.py [new file with mode: 0644]
gguf-py/scripts/gguf-convert-endian.py [new file with mode: 0755]
gguf-py/scripts/gguf-dump.py [new file with mode: 0755]
gguf-py/scripts/gguf-set-metadata.py [new file with mode: 0755]
gguf-py/tests/test_gguf.py

index 67ccbe99f132af8bc4a7179d21ccbfa4112f5dc2..789602351ca9d74709d466d528c9e87aab61c990 100755 (executable)
@@ -16,7 +16,7 @@ import torch
 from sentencepiece import SentencePieceProcessor  # type: ignore[import]
 
 if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
 
 
index 871add64d4ca737d70f7185e8336f1707c6de6b2..d898d81c4c445b1227937c508a399920815f0021 100755 (executable)
@@ -12,29 +12,9 @@ import numpy as np
 
 import os
 if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
 
-# Note: Does not support GGML_QKK_64
-QK_K = 256
-# Items here are (block size, type size)
-GGML_QUANT_SIZES = {
-    gguf.GGMLQuantizationType.F32  : (1, 4),
-    gguf.GGMLQuantizationType.F16  : (1, 2),
-    gguf.GGMLQuantizationType.Q4_0 : (32, 2 + 16),
-    gguf.GGMLQuantizationType.Q4_1 : (32, 2 + 2 + 16),
-    gguf.GGMLQuantizationType.Q5_0 : (32, 2 + 4 + 16),
-    gguf.GGMLQuantizationType.Q5_1 : (32, 2 + 2 + 4 + 16),
-    gguf.GGMLQuantizationType.Q8_0 : (32, 2 + 32),
-    gguf.GGMLQuantizationType.Q8_1 : (32, 4 + 4 + 32),
-    gguf.GGMLQuantizationType.Q2_K : (256, 2 + 2 + QK_K // 16 + QK_K // 4),
-    gguf.GGMLQuantizationType.Q3_K : (256, 2 + QK_K // 4 + QK_K // 8 + 12),
-    gguf.GGMLQuantizationType.Q4_K : (256, 2 + 2 + QK_K // 2 + 12),
-    gguf.GGMLQuantizationType.Q5_K : (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
-    gguf.GGMLQuantizationType.Q6_K : (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
-    gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8),
-}
-
 class GGMLFormat(IntEnum):
     GGML = 0
     GGMF = 1
@@ -125,7 +105,7 @@ class Tensor:
         (n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
         assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
         assert name_len < 4096, 'Absurd tensor name length'
-        quant = GGML_QUANT_SIZES.get(dtype)
+        quant = gguf.GGML_QUANT_SIZES.get(dtype)
         assert quant is not None, 'Unknown tensor type'
         (blksize, tysize) = quant
         offset += 12
index e022ffe46189e5160f2d60c2a1859a09c3a69e7e..240f87306e5783380fab0e43bec47099e103a15e 100644 (file)
@@ -6,7 +6,7 @@ import argparse
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
 if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
 
 def _flatten_dict(dct, tensors, prefix=None):
index b0f44dbef8332a83510e492c2267690b31993ae7..a4b87e08849bcc1037c33d6cd6cb7cc652fd2e1b 100755 (executable)
@@ -3,11 +3,9 @@ from __future__ import annotations
 
 import argparse
 import concurrent.futures
-import copy
 import enum
 import faulthandler
 import functools
-import io
 import itertools
 import json
 import math
@@ -23,14 +21,14 @@ from abc import ABCMeta, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Generator, Iterable, Literal, Sequence, TypeVar
+from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
 
 import os
 if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
 
 if TYPE_CHECKING:
@@ -851,7 +849,7 @@ class OutputFile:
         elif isinstance(vocab, BpeVocab):
             self.gguf.add_tokenizer_model("gpt2")
         else:
-            raise ValueError(f'Unknown vocab type: Not BpeVocab or SentencePieceVocab')
+            raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab')
         self.gguf.add_token_list(tokens)
         self.gguf.add_token_scores(scores)
         self.gguf.add_token_types(toktypes)
@@ -905,7 +903,7 @@ class OutputFile:
         return dt.quantize(arr)
 
     @staticmethod
-    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess=gguf.GGUFEndian.LITTLE) -> None:
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
@@ -1114,11 +1112,15 @@ def do_dump_model(model_plus: ModelPlus) -> None:
 
 
 def main(args_in: list[str] | None = None) -> None:
+    output_choices = ["f32", "f16"]
+    if np.uint32(1) == np.uint32(1).newbyteorder("<"):
+        # We currently only support Q8_0 output on little endian systems.
+        output_choices.append("q8_0")
     parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
     parser.add_argument("--dump",        action="store_true",    help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true",    help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
-    parser.add_argument("--outtype",     choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
+    parser.add_argument("--outtype",     choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
     parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
index 887ed2e212786d4f09f2f0721beca1f67bc6871a..ed93673bcf306aeb754d08677d56db02e2794bf1 100644 (file)
@@ -9,7 +9,7 @@ import numpy as np
 from pathlib import Path
 
 if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / '..' / '..' / 'gguf-py' / 'gguf'))
+    sys.path.insert(1, str(Path(__file__).parent / '..' / '..' / 'gguf-py'))
 import gguf
 
 # gguf constants
index a28d8c57adc7d010a108a94a150b6511c64c9eca..502b6a510cc70dcb2d62288b94b9b37fe1e72949 100644 (file)
@@ -11,6 +11,16 @@ as an example for its usage.
 pip install gguf
 ```
 
+## API Examples/Simple Tools
+
+[examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.
+
+[scripts/gguf-dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-dump.py) — Dumps a GGUF file's metadata to the console.
+
+[scripts/gguf-set-metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-set-metadata.py) — Allows changing simple metadata values in a GGUF file by key.
+
+[scripts/gguf-convert-endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf-convert-endian.py) — Allows converting the endianness of GGUF files.
+
 ## Development
 Maintainers who participate in development of this package are advised to install it in editable mode:
 
diff --git a/gguf-py/examples/writer.py b/gguf-py/examples/writer.py
new file mode 100755 (executable)
index 0000000..f39eed1
--- /dev/null
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+import sys
+from pathlib import Path
+
+import numpy as np
+
+# Necessary to load the local gguf package
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from gguf import GGUFWriter  # noqa: E402
+
+
+# Example usage:
+def writer_example() -> None:
+    # Example usage with a file
+    gguf_writer = GGUFWriter("example.gguf", "llama")
+
+    gguf_writer.add_architecture()
+    gguf_writer.add_block_count(12)
+    gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer
+    gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float
+    gguf_writer.add_custom_alignment(64)
+
+    tensor1 = np.ones((32,), dtype=np.float32) * 100.0
+    tensor2 = np.ones((64,), dtype=np.float32) * 101.0
+    tensor3 = np.ones((96,), dtype=np.float32) * 102.0
+
+    gguf_writer.add_tensor("tensor1", tensor1)
+    gguf_writer.add_tensor("tensor2", tensor2)
+    gguf_writer.add_tensor("tensor3", tensor3)
+
+    gguf_writer.write_header_to_file()
+    gguf_writer.write_kv_data_to_file()
+    gguf_writer.write_tensors_to_file()
+
+    gguf_writer.close()
+
+
+if __name__ == '__main__':
+    writer_example()
index f9b70a85b875e3626c518c321156abc5588b8ffe..110ab342ccd719fe9617b73c51c0f38ca6c7872f 100644 (file)
@@ -1 +1,5 @@
-from .gguf import *
+from .constants import *
+from .gguf_reader import *
+from .gguf_writer import *
+from .tensor_mapping import *
+from .vocab import *
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
new file mode 100644 (file)
index 0000000..bf1ccf6
--- /dev/null
@@ -0,0 +1,470 @@
+from __future__ import annotations
+
+import sys
+from enum import Enum, IntEnum, auto
+from typing import Any
+
+#
+# constants
+#
+
+GGUF_MAGIC             = 0x46554747  # "GGUF"
+GGUF_VERSION           = 3
+GGUF_DEFAULT_ALIGNMENT = 32
+
+#
+# metadata keys
+#
+
+
+class Keys:
+    class General:
+        ARCHITECTURE         = "general.architecture"
+        QUANTIZATION_VERSION = "general.quantization_version"
+        ALIGNMENT            = "general.alignment"
+        NAME                 = "general.name"
+        AUTHOR               = "general.author"
+        URL                  = "general.url"
+        DESCRIPTION          = "general.description"
+        LICENSE              = "general.license"
+        SOURCE_URL           = "general.source.url"
+        SOURCE_HF_REPO       = "general.source.huggingface.repository"
+        FILE_TYPE            = "general.file_type"
+
+    class LLM:
+        CONTEXT_LENGTH        = "{arch}.context_length"
+        EMBEDDING_LENGTH      = "{arch}.embedding_length"
+        BLOCK_COUNT           = "{arch}.block_count"
+        FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
+        USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
+        TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
+
+    class Attention:
+        HEAD_COUNT        = "{arch}.attention.head_count"
+        HEAD_COUNT_KV     = "{arch}.attention.head_count_kv"
+        MAX_ALIBI_BIAS    = "{arch}.attention.max_alibi_bias"
+        CLAMP_KQV         = "{arch}.attention.clamp_kqv"
+        LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
+        LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
+
+    class Rope:
+        DIMENSION_COUNT      = "{arch}.rope.dimension_count"
+        FREQ_BASE            = "{arch}.rope.freq_base"
+        SCALING_TYPE         = "{arch}.rope.scaling.type"
+        SCALING_FACTOR       = "{arch}.rope.scaling.factor"
+        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
+
+    class Tokenizer:
+        MODEL      = "tokenizer.ggml.model"
+        LIST       = "tokenizer.ggml.tokens"
+        TOKEN_TYPE = "tokenizer.ggml.token_type"
+        SCORES     = "tokenizer.ggml.scores"
+        MERGES     = "tokenizer.ggml.merges"
+        BOS_ID     = "tokenizer.ggml.bos_token_id"
+        EOS_ID     = "tokenizer.ggml.eos_token_id"
+        UNK_ID     = "tokenizer.ggml.unknown_token_id"
+        SEP_ID     = "tokenizer.ggml.seperator_token_id"
+        PAD_ID     = "tokenizer.ggml.padding_token_id"
+        ADD_BOS    = "tokenizer.ggml.add_bos_token"
+        ADD_EOS    = "tokenizer.ggml.add_eos_token"
+        HF_JSON    = "tokenizer.huggingface.json"
+        RWKV       = "tokenizer.rwkv.world"
+
+
+#
+# recommended mapping of model tensor names for storage in gguf
+#
+
+
+class MODEL_ARCH(IntEnum):
+    LLAMA     = auto()
+    FALCON    = auto()
+    BAICHUAN  = auto()
+    GPT2      = auto()
+    GPTJ      = auto()
+    GPTNEOX   = auto()
+    MPT       = auto()
+    STARCODER = auto()
+    PERSIMMON = auto()
+    REFACT    = auto()
+    BERT      = auto()
+    BLOOM     = auto()
+
+
+class MODEL_TENSOR(IntEnum):
+    TOKEN_EMBD      = auto()
+    TOKEN_EMBD_NORM = auto()
+    TOKEN_TYPES     = auto()
+    POS_EMBD        = auto()
+    OUTPUT          = auto()
+    OUTPUT_NORM     = auto()
+    ROPE_FREQS      = auto()
+    ATTN_Q          = auto()
+    ATTN_K          = auto()
+    ATTN_V          = auto()
+    ATTN_QKV        = auto()
+    ATTN_OUT        = auto()
+    ATTN_NORM       = auto()
+    ATTN_NORM_2     = auto()
+    ATTN_ROT_EMBD   = auto()
+    FFN_GATE        = auto()
+    FFN_DOWN        = auto()
+    FFN_UP          = auto()
+    FFN_NORM        = auto()
+    ATTN_Q_NORM     = auto()
+    ATTN_K_NORM     = auto()
+
+
+MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
+    MODEL_ARCH.LLAMA:          "llama",
+    MODEL_ARCH.FALCON:         "falcon",
+    MODEL_ARCH.BAICHUAN:       "baichuan",
+    MODEL_ARCH.GPT2:           "gpt2",
+    MODEL_ARCH.GPTJ:           "gptj",
+    MODEL_ARCH.GPTNEOX:        "gptneox",
+    MODEL_ARCH.MPT:            "mpt",
+    MODEL_ARCH.STARCODER:      "starcoder",
+    MODEL_ARCH.PERSIMMON:      "persimmon",
+    MODEL_ARCH.REFACT:         "refact",
+    MODEL_ARCH.BERT:           "bert",
+    MODEL_ARCH.BLOOM:          "bloom",
+}
+
+TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
+    MODEL_TENSOR.TOKEN_EMBD:      "token_embd",
+    MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
+    MODEL_TENSOR.TOKEN_TYPES:     "token_types",
+    MODEL_TENSOR.POS_EMBD:        "position_embd",
+    MODEL_TENSOR.OUTPUT_NORM:     "output_norm",
+    MODEL_TENSOR.OUTPUT:          "output",
+    MODEL_TENSOR.ROPE_FREQS:      "rope_freqs",
+    MODEL_TENSOR.ATTN_NORM:       "blk.{bid}.attn_norm",
+    MODEL_TENSOR.ATTN_NORM_2:     "blk.{bid}.attn_norm_2",
+    MODEL_TENSOR.ATTN_QKV:        "blk.{bid}.attn_qkv",
+    MODEL_TENSOR.ATTN_Q:          "blk.{bid}.attn_q",
+    MODEL_TENSOR.ATTN_K:          "blk.{bid}.attn_k",
+    MODEL_TENSOR.ATTN_V:          "blk.{bid}.attn_v",
+    MODEL_TENSOR.ATTN_OUT:        "blk.{bid}.attn_output",
+    MODEL_TENSOR.ATTN_ROT_EMBD:   "blk.{bid}.attn_rot_embd",
+    MODEL_TENSOR.ATTN_Q_NORM:     "blk.{bid}.attn_q_norm",
+    MODEL_TENSOR.ATTN_K_NORM:     "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.FFN_NORM:        "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
+    MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
+    MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
+}
+
+MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+    MODEL_ARCH.LLAMA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPTNEOX: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.FALCON: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_NORM_2,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.BAICHUAN: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.STARCODER: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.MPT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPTJ: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.PERSIMMON: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
+    MODEL_ARCH.REFACT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.BLOOM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.GPT2: [
+        # TODO
+    ],
+    # TODO
+}
+
+# tensors that will not be serialized
+MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+    MODEL_ARCH.LLAMA: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
+    MODEL_ARCH.BAICHUAN: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
+    MODEL_ARCH.PERSIMMON: [
+        MODEL_TENSOR.ROPE_FREQS,
+    ],
+}
+
+#
+# types
+#
+
+
+class TokenType(IntEnum):
+    NORMAL       = 1
+    UNKNOWN      = 2
+    CONTROL      = 3
+    USER_DEFINED = 4
+    UNUSED       = 5
+    BYTE         = 6
+
+
+class RopeScalingType(Enum):
+    NONE   = 'none'
+    LINEAR = 'linear'
+    YARN   = 'yarn'
+
+
+class GGMLQuantizationType(IntEnum):
+    F32  = 0
+    F16  = 1
+    Q4_0 = 2
+    Q4_1 = 3
+    Q5_0 = 6
+    Q5_1 = 7
+    Q8_0 = 8
+    Q8_1 = 9
+    Q2_K = 10
+    Q3_K = 11
+    Q4_K = 12
+    Q5_K = 13
+    Q6_K = 14
+    Q8_K = 15
+
+
+class GGUFEndian(IntEnum):
+    LITTLE = 0
+    BIG = 1
+
+
+class GGUFValueType(IntEnum):
+    UINT8   = 0
+    INT8    = 1
+    UINT16  = 2
+    INT16   = 3
+    UINT32  = 4
+    INT32   = 5
+    FLOAT32 = 6
+    BOOL    = 7
+    STRING  = 8
+    ARRAY   = 9
+    UINT64  = 10
+    INT64   = 11
+    FLOAT64 = 12
+
+    @staticmethod
+    def get_type(val: Any) -> GGUFValueType:
+        if isinstance(val, (str, bytes, bytearray)):
+            return GGUFValueType.STRING
+        elif isinstance(val, list):
+            return GGUFValueType.ARRAY
+        elif isinstance(val, float):
+            return GGUFValueType.FLOAT32
+        elif isinstance(val, bool):
+            return GGUFValueType.BOOL
+        elif isinstance(val, int):
+            return GGUFValueType.INT32
+        # TODO: need help with 64-bit types in Python
+        else:
+            print("Unknown type:", type(val))
+            sys.exit()
+
+
+# Note: Does not support GGML_QKK_64
+QK_K = 256
+# Items here are (block size, type size)
+GGML_QUANT_SIZES = {
+    GGMLQuantizationType.F32:  (1, 4),
+    GGMLQuantizationType.F16:  (1, 2),
+    GGMLQuantizationType.Q4_0: (32, 2 + 16),
+    GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
+    GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
+    GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
+    GGMLQuantizationType.Q8_0: (32, 2 + 32),
+    GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
+    GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
+    GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
+    GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
+    GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
+    GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
+    GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
+}
+
+
+# Aliases for backward compatibility.
+
+# general
+KEY_GENERAL_ARCHITECTURE         = Keys.General.ARCHITECTURE
+KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION
+KEY_GENERAL_ALIGNMENT            = Keys.General.ALIGNMENT
+KEY_GENERAL_NAME                 = Keys.General.NAME
+KEY_GENERAL_AUTHOR               = Keys.General.AUTHOR
+KEY_GENERAL_URL                  = Keys.General.URL
+KEY_GENERAL_DESCRIPTION          = Keys.General.DESCRIPTION
+KEY_GENERAL_LICENSE              = Keys.General.LICENSE
+KEY_GENERAL_SOURCE_URL           = Keys.General.SOURCE_URL
+KEY_GENERAL_SOURCE_HF_REPO       = Keys.General.SOURCE_HF_REPO
+KEY_GENERAL_FILE_TYPE            = Keys.General.FILE_TYPE
+
+# LLM
+KEY_CONTEXT_LENGTH        = Keys.LLM.CONTEXT_LENGTH
+KEY_EMBEDDING_LENGTH      = Keys.LLM.EMBEDDING_LENGTH
+KEY_BLOCK_COUNT           = Keys.LLM.BLOCK_COUNT
+KEY_FEED_FORWARD_LENGTH   = Keys.LLM.FEED_FORWARD_LENGTH
+KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
+KEY_TENSOR_DATA_LAYOUT    = Keys.LLM.TENSOR_DATA_LAYOUT
+
+# attention
+KEY_ATTENTION_HEAD_COUNT        = Keys.Attention.HEAD_COUNT
+KEY_ATTENTION_HEAD_COUNT_KV     = Keys.Attention.HEAD_COUNT_KV
+KEY_ATTENTION_MAX_ALIBI_BIAS    = Keys.Attention.MAX_ALIBI_BIAS
+KEY_ATTENTION_CLAMP_KQV         = Keys.Attention.CLAMP_KQV
+KEY_ATTENTION_LAYERNORM_EPS     = Keys.Attention.LAYERNORM_EPS
+KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
+
+# RoPE
+KEY_ROPE_DIMENSION_COUNT      = Keys.Rope.DIMENSION_COUNT
+KEY_ROPE_FREQ_BASE            = Keys.Rope.FREQ_BASE
+KEY_ROPE_SCALING_TYPE         = Keys.Rope.SCALING_TYPE
+KEY_ROPE_SCALING_FACTOR       = Keys.Rope.SCALING_FACTOR
+KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
+KEY_ROPE_SCALING_FINETUNED    = Keys.Rope.SCALING_FINETUNED
+
+# tokenization
+KEY_TOKENIZER_MODEL      = Keys.Tokenizer.MODEL
+KEY_TOKENIZER_LIST       = Keys.Tokenizer.LIST
+KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE
+KEY_TOKENIZER_SCORES     = Keys.Tokenizer.SCORES
+KEY_TOKENIZER_MERGES     = Keys.Tokenizer.MERGES
+KEY_TOKENIZER_BOS_ID     = Keys.Tokenizer.BOS_ID
+KEY_TOKENIZER_EOS_ID     = Keys.Tokenizer.EOS_ID
+KEY_TOKENIZER_UNK_ID     = Keys.Tokenizer.UNK_ID
+KEY_TOKENIZER_SEP_ID     = Keys.Tokenizer.SEP_ID
+KEY_TOKENIZER_PAD_ID     = Keys.Tokenizer.PAD_ID
+KEY_TOKENIZER_HF_JSON    = Keys.Tokenizer.HF_JSON
+KEY_TOKENIZER_RWKV       = Keys.Tokenizer.RWKV
index 7e495cb19638d1c952fbad34b7217e36110e1553..651a81eb828248728f854c85c1a437b52892f275 100644 (file)
-#!/usr/bin/env python3
-from __future__ import annotations
+# This file left for compatibility. If you want to use the GGUF API from Python
+# then don't import gguf/gguf.py directly. If you're looking for examples, see the
+# examples/ directory for gguf-py
 
-import json
-import os
-import shutil
-import struct
+import importlib
 import sys
-import tempfile
-from enum import Enum, IntEnum, auto
-from io import BufferedWriter
 from pathlib import Path
-from typing import IO, Any, BinaryIO, Callable, Sequence
 
-import numpy as np
+sys.path.insert(0, str(Path(__file__).parent.parent))
 
-#
-# constants
-#
+# Compatibility for people trying to import gguf/gguf.py directly instead of as a package.
+importlib.invalidate_caches()
+import gguf  # noqa: E402
 
-GGUF_MAGIC             = 0x46554747
-GGUF_VERSION           = 3
-GGUF_DEFAULT_ALIGNMENT = 32
-
-
-# general
-KEY_GENERAL_ARCHITECTURE         = "general.architecture"
-KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version"
-KEY_GENERAL_ALIGNMENT            = "general.alignment"
-KEY_GENERAL_NAME                 = "general.name"
-KEY_GENERAL_AUTHOR               = "general.author"
-KEY_GENERAL_URL                  = "general.url"
-KEY_GENERAL_DESCRIPTION          = "general.description"
-KEY_GENERAL_LICENSE              = "general.license"
-KEY_GENERAL_SOURCE_URL           = "general.source.url"
-KEY_GENERAL_SOURCE_HF_REPO       = "general.source.huggingface.repository"
-KEY_GENERAL_FILE_TYPE            = "general.file_type"
-
-# LLM
-KEY_CONTEXT_LENGTH        = "{arch}.context_length"
-KEY_EMBEDDING_LENGTH      = "{arch}.embedding_length"
-KEY_BLOCK_COUNT           = "{arch}.block_count"
-KEY_FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
-KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-KEY_TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
-
-# attention
-KEY_ATTENTION_HEAD_COUNT        = "{arch}.attention.head_count"
-KEY_ATTENTION_HEAD_COUNT_KV     = "{arch}.attention.head_count_kv"
-KEY_ATTENTION_MAX_ALIBI_BIAS    = "{arch}.attention.max_alibi_bias"
-KEY_ATTENTION_CLAMP_KQV         = "{arch}.attention.clamp_kqv"
-KEY_ATTENTION_LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
-KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
-
-# RoPE
-KEY_ROPE_DIMENSION_COUNT         = "{arch}.rope.dimension_count"
-KEY_ROPE_FREQ_BASE               = "{arch}.rope.freq_base"
-KEY_ROPE_SCALING_TYPE            = "{arch}.rope.scaling.type"
-KEY_ROPE_SCALING_FACTOR          = "{arch}.rope.scaling.factor"
-KEY_ROPE_SCALING_ORIG_CTX_LEN    = "{arch}.rope.scaling.original_context_length"
-KEY_ROPE_SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned"
-
-# tokenization
-KEY_TOKENIZER_MODEL      = "tokenizer.ggml.model"
-KEY_TOKENIZER_LIST       = "tokenizer.ggml.tokens"
-KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type"
-KEY_TOKENIZER_SCORES     = "tokenizer.ggml.scores"
-KEY_TOKENIZER_MERGES     = "tokenizer.ggml.merges"
-KEY_TOKENIZER_BOS_ID     = "tokenizer.ggml.bos_token_id"
-KEY_TOKENIZER_EOS_ID     = "tokenizer.ggml.eos_token_id"
-KEY_TOKENIZER_UNK_ID     = "tokenizer.ggml.unknown_token_id"
-KEY_TOKENIZER_SEP_ID     = "tokenizer.ggml.seperator_token_id"
-KEY_TOKENIZER_PAD_ID     = "tokenizer.ggml.padding_token_id"
-KEY_TOKENIZER_HF_JSON    = "tokenizer.huggingface.json"
-KEY_TOKENIZER_RWKV       = "tokenizer.rwkv.world"
-
-
-#
-# recommended mapping of model tensor names for storage in gguf
-#
-
-
-class MODEL_ARCH(IntEnum):
-    LLAMA         : int = auto()
-    FALCON        : int = auto()
-    BAICHUAN      : int = auto()
-    GPT2          : int = auto()
-    GPTJ          : int = auto()
-    GPTNEOX       : int = auto()
-    MPT           : int = auto()
-    STARCODER     : int = auto()
-    PERSIMMON     : int = auto()
-    REFACT        : int = auto()
-    BERT          : int = auto()
-    BLOOM         : int = auto()
-
-
-class MODEL_TENSOR(IntEnum):
-    TOKEN_EMBD      : int = auto()
-    TOKEN_EMBD_NORM : int = auto()
-    TOKEN_TYPES     : int = auto()
-    POS_EMBD        : int = auto()
-    OUTPUT          : int = auto()
-    OUTPUT_NORM     : int = auto()
-    ROPE_FREQS      : int = auto()
-    ATTN_Q          : int = auto()
-    ATTN_K          : int = auto()
-    ATTN_V          : int = auto()
-    ATTN_QKV        : int = auto()
-    ATTN_OUT        : int = auto()
-    ATTN_NORM       : int = auto()
-    ATTN_NORM_2     : int = auto()
-    ATTN_ROT_EMBD   : int = auto()
-    FFN_GATE        : int = auto()
-    FFN_DOWN        : int = auto()
-    FFN_UP          : int = auto()
-    FFN_NORM        : int = auto()
-    ATTN_Q_NORM     : int = auto()
-    ATTN_K_NORM     : int = auto()
-
-
-MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
-    MODEL_ARCH.LLAMA:          "llama",
-    MODEL_ARCH.FALCON:         "falcon",
-    MODEL_ARCH.BAICHUAN:       "baichuan",
-    MODEL_ARCH.GPT2:           "gpt2",
-    MODEL_ARCH.GPTJ:           "gptj",
-    MODEL_ARCH.GPTNEOX:        "gptneox",
-    MODEL_ARCH.MPT:            "mpt",
-    MODEL_ARCH.STARCODER:      "starcoder",
-    MODEL_ARCH.PERSIMMON:      "persimmon",
-    MODEL_ARCH.REFACT:         "refact",
-    MODEL_ARCH.BERT:           "bert",
-    MODEL_ARCH.BLOOM:          "bloom",
-}
-
-TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
-    MODEL_TENSOR.TOKEN_EMBD:      "token_embd",
-    MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
-    MODEL_TENSOR.TOKEN_TYPES:     "token_types",
-    MODEL_TENSOR.POS_EMBD:        "position_embd",
-    MODEL_TENSOR.OUTPUT_NORM:     "output_norm",
-    MODEL_TENSOR.OUTPUT:          "output",
-    MODEL_TENSOR.ROPE_FREQS:      "rope_freqs",
-    MODEL_TENSOR.ATTN_NORM:       "blk.{bid}.attn_norm",
-    MODEL_TENSOR.ATTN_NORM_2:     "blk.{bid}.attn_norm_2",
-    MODEL_TENSOR.ATTN_QKV:        "blk.{bid}.attn_qkv",
-    MODEL_TENSOR.ATTN_Q:          "blk.{bid}.attn_q",
-    MODEL_TENSOR.ATTN_K:          "blk.{bid}.attn_k",
-    MODEL_TENSOR.ATTN_V:          "blk.{bid}.attn_v",
-    MODEL_TENSOR.ATTN_OUT:        "blk.{bid}.attn_output",
-    MODEL_TENSOR.ATTN_ROT_EMBD:   "blk.{bid}.attn_rot_embd",
-    MODEL_TENSOR.ATTN_Q_NORM:     "blk.{bid}.attn_q_norm",
-    MODEL_TENSOR.ATTN_K_NORM:     "blk.{bid}.attn_k_norm",
-    MODEL_TENSOR.FFN_NORM:        "blk.{bid}.ffn_norm",
-    MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
-    MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
-    MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
-}
-
-MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
-    MODEL_ARCH.LLAMA: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ROPE_FREQS,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_Q,
-        MODEL_TENSOR.ATTN_K,
-        MODEL_TENSOR.ATTN_V,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.ATTN_ROT_EMBD,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_GATE,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.GPTNEOX: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_QKV,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.FALCON: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_NORM_2,
-        MODEL_TENSOR.ATTN_QKV,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.BAICHUAN: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ROPE_FREQS,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_Q,
-        MODEL_TENSOR.ATTN_K,
-        MODEL_TENSOR.ATTN_V,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.ATTN_ROT_EMBD,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_GATE,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.STARCODER: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.POS_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_QKV,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.BERT: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.TOKEN_TYPES,
-        MODEL_TENSOR.POS_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_Q,
-        MODEL_TENSOR.ATTN_K,
-        MODEL_TENSOR.ATTN_V,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.MPT: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_QKV,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.GPTJ: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_Q,
-        MODEL_TENSOR.ATTN_K,
-        MODEL_TENSOR.ATTN_V,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.PERSIMMON: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_QKV,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-        MODEL_TENSOR.ATTN_Q_NORM,
-        MODEL_TENSOR.ATTN_K_NORM,
-        MODEL_TENSOR.ATTN_ROT_EMBD,
-    ],
-    MODEL_ARCH.REFACT: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_Q,
-        MODEL_TENSOR.ATTN_K,
-        MODEL_TENSOR.ATTN_V,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_GATE,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.BLOOM: [
-        MODEL_TENSOR.TOKEN_EMBD,
-        MODEL_TENSOR.TOKEN_EMBD_NORM,
-        MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.OUTPUT,
-        MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_QKV,
-        MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.FFN_NORM,
-        MODEL_TENSOR.FFN_DOWN,
-        MODEL_TENSOR.FFN_UP,
-    ],
-    MODEL_ARCH.GPT2: [
-        # TODO
-    ],
-    # TODO
-}
-
-# tensors that will not be serialized
-MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
-    MODEL_ARCH.LLAMA: [
-        MODEL_TENSOR.ROPE_FREQS,
-        MODEL_TENSOR.ATTN_ROT_EMBD,
-    ],
-    MODEL_ARCH.BAICHUAN: [
-        MODEL_TENSOR.ROPE_FREQS,
-        MODEL_TENSOR.ATTN_ROT_EMBD,
-    ],
-    MODEL_ARCH.PERSIMMON: [
-        MODEL_TENSOR.ROPE_FREQS,
-    ]
-}
-
-
-class TensorNameMap:
-    mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
-        # Token embeddings
-        MODEL_TENSOR.TOKEN_EMBD: (
-            "gpt_neox.embed_in",                        # gptneox
-            "transformer.wte",                          # gpt2 gpt-j mpt refact
-            "transformer.word_embeddings",              # falcon
-            "word_embeddings",                          # bloom
-            "model.embed_tokens",                       # llama-hf
-            "tok_embeddings",                           # llama-pth
-            "embeddings.word_embeddings",               # bert
-            "language_model.embedding.word_embeddings", # persimmon
-        ),
-
-        # Token type embeddings
-        MODEL_TENSOR.TOKEN_TYPES: (
-            "embeddings.token_type_embeddings",  # bert
-        ),
-
-        # Normalization of token embeddings
-        MODEL_TENSOR.TOKEN_EMBD_NORM: (
-            "word_embeddings_layernorm",  # bloom
-        ),
-
-        # Position embeddings
-        MODEL_TENSOR.POS_EMBD: (
-            "transformer.wpe",                 # gpt2
-            "embeddings.position_embeddings",  # bert
-        ),
-
-        # Output
-        MODEL_TENSOR.OUTPUT: (
-            "embed_out",                # gptneox
-            "lm_head",                  # gpt2 mpt falcon llama-hf baichuan
-            "output",                   # llama-pth bloom
-            "word_embeddings_for_head", # persimmon
-        ),
-
-        # Output norm
-        MODEL_TENSOR.OUTPUT_NORM: (
-            "gpt_neox.final_layer_norm",              # gptneox
-            "transformer.ln_f",                       # gpt2 gpt-j falcon
-            "model.norm",                             # llama-hf baichuan
-            "norm",                                   # llama-pth
-            "embeddings.LayerNorm",                   # bert
-            "transformer.norm_f",                     # mpt
-            "ln_f",                                   # refact bloom
-            "language_model.encoder.final_layernorm", # persimmon
-        ),
-
-        # Rope frequencies
-        MODEL_TENSOR.ROPE_FREQS: (
-            "rope.freqs", # llama-pth
-        ),
-    }
-
-    block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
-        # Attention norm
-        MODEL_TENSOR.ATTN_NORM: (
-            "gpt_neox.layers.{bid}.input_layernorm",               # gptneox
-            "transformer.h.{bid}.ln_1",                            # gpt2 gpt-j refact
-            "transformer.blocks.{bid}.norm_1",                     # mpt
-            "transformer.h.{bid}.input_layernorm",                 # falcon7b
-            "h.{bid}.input_layernorm",                             # bloom
-            "transformer.h.{bid}.ln_mlp",                          # falcon40b
-            "model.layers.{bid}.input_layernorm",                  # llama-hf
-            "layers.{bid}.attention_norm",                         # llama-pth
-            "encoder.layer.{bid}.attention.output.LayerNorm",      # bert
-            "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
-            "model.layers.{bid}.ln1",                              # yi
-        ),
-
-        # Attention norm 2
-        MODEL_TENSOR.ATTN_NORM_2: (
-            "transformer.h.{bid}.ln_attn", # falcon40b
-        ),
-
-        # Attention query-key-value
-        MODEL_TENSOR.ATTN_QKV: (
-            "gpt_neox.layers.{bid}.attention.query_key_value",                    # gptneox
-            "transformer.h.{bid}.attn.c_attn",                                    # gpt2
-            "transformer.blocks.{bid}.attn.Wqkv",                                 # mpt
-            "transformer.h.{bid}.self_attention.query_key_value",                 # falcon
-            "h.{bid}.self_attention.query_key_value",                             # bloom
-            "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
-        ),
-
-        # Attention query
-        MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
-            "layers.{bid}.attention.wq",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.query",  # bert
-            "transformer.h.{bid}.attn.q_proj",           # gpt-j
-        ),
-
-        # Attention key
-        MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",     # llama-hf
-            "layers.{bid}.attention.wk",               # llama-pth
-            "encoder.layer.{bid}.attention.self.key",  # bert
-            "transformer.h.{bid}.attn.k_proj",         # gpt-j
-        ),
-
-        # Attention value
-        MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
-            "layers.{bid}.attention.wv",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.value",  # bert
-            "transformer.h.{bid}.attn.v_proj",           # gpt-j
-        ),
-
-        # Attention output
-        MODEL_TENSOR.ATTN_OUT: (
-            "gpt_neox.layers.{bid}.attention.dense",                   # gptneox
-            "transformer.h.{bid}.attn.c_proj",                         # gpt2 refact
-            "transformer.blocks.{bid}.attn.out_proj",                  # mpt
-            "transformer.h.{bid}.self_attention.dense",                # falcon
-            "h.{bid}.self_attention.dense",                            # bloom
-            "model.layers.{bid}.self_attn.o_proj",                     # llama-hf
-            "layers.{bid}.attention.wo",                               # llama-pth
-            "encoder.layer.{bid}.attention.output.dense",              # bert
-            "transformer.h.{bid}.attn.out_proj",                       # gpt-j
-            "language_model.encoder.layers.{bid}.self_attention.dense" # persimmon
-        ),
-
-        # Rotary embeddings
-        MODEL_TENSOR.ATTN_ROT_EMBD: (
-            "model.layers.{bid}.self_attn.rotary_emb.inv_freq",  # llama-hf
-            "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
-        ),
-
-        # Feed-forward norm
-        MODEL_TENSOR.FFN_NORM: (
-            "gpt_neox.layers.{bid}.post_attention_layernorm",               # gptneox
-            "transformer.h.{bid}.ln_2",                                     # gpt2 refact
-            "h.{bid}.post_attention_layernorm",                             # bloom
-            "transformer.blocks.{bid}.norm_2",                              # mpt
-            "model.layers.{bid}.post_attention_layernorm",                  # llama-hf
-            "layers.{bid}.ffn_norm",                                        # llama-pth
-            "encoder.layer.{bid}.output.LayerNorm",                         # bert
-            "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
-            "model.layers.{bid}.ln2",                                       # yi
-        ),
-
-        # Feed-forward up
-        MODEL_TENSOR.FFN_UP: (
-            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",               # gptneox
-            "transformer.h.{bid}.mlp.c_fc",                          # gpt2
-            "transformer.blocks.{bid}.ffn.up_proj",                  # mpt
-            "transformer.h.{bid}.mlp.dense_h_to_4h",                 # falcon
-            "h.{bid}.mlp.dense_h_to_4h",                             # bloom
-            "model.layers.{bid}.mlp.up_proj",                        # llama-hf refact
-            "layers.{bid}.feed_forward.w3",                          # llama-pth
-            "encoder.layer.{bid}.intermediate.dense",                # bert
-            "transformer.h.{bid}.mlp.fc_in",                         # gpt-j
-            "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
-        ),
-
-        # Feed-forward gate
-        MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
-            "layers.{bid}.feed_forward.w1",     # llama-pth
-        ),
-
-        # Feed-forward down
-        MODEL_TENSOR.FFN_DOWN: (
-            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",               # gptneox
-            "transformer.h.{bid}.mlp.c_proj",                        # gpt2 refact
-            "transformer.blocks.{bid}.ffn.down_proj",                # mpt
-            "transformer.h.{bid}.mlp.dense_4h_to_h",                 # falcon
-            "h.{bid}.mlp.dense_4h_to_h",                             # bloom
-            "model.layers.{bid}.mlp.down_proj",                      # llama-hf
-            "layers.{bid}.feed_forward.w2",                          # llama-pth
-            "encoder.layer.{bid}.output.dense",                      # bert
-            "transformer.h.{bid}.mlp.fc_out",                        # gpt-j
-            "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
-        ),
-
-        MODEL_TENSOR.ATTN_Q_NORM: (
-            "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
-        ),
-
-        MODEL_TENSOR.ATTN_K_NORM: (
-            "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
-        ),
-
-        MODEL_TENSOR.ROPE_FREQS: (
-            "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
-        )
-    }
-
-    mapping: dict[str, tuple[MODEL_TENSOR, str]]
-
-    def __init__(self, arch: MODEL_ARCH, n_blocks: int):
-        self.mapping = {}
-        for tensor, keys in self.mappings_cfg.items():
-            if tensor not in MODEL_TENSORS[arch]:
-                continue
-            tensor_name = TENSOR_NAMES[tensor]
-            self.mapping[tensor_name] = (tensor, tensor_name)
-            for key in keys:
-                self.mapping[key] = (tensor, tensor_name)
-        for bid in range(n_blocks):
-            for tensor, keys in self.block_mappings_cfg.items():
-                if tensor not in MODEL_TENSORS[arch]:
-                    continue
-                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
-                self.mapping[tensor_name] = (tensor, tensor_name)
-                for key in keys:
-                    key = key.format(bid = bid)
-                    self.mapping[key] = (tensor, tensor_name)
-
-    def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
-        result = self.mapping.get(key)
-        if result is not None:
-            return result
-        for suffix in try_suffixes:
-            if key.endswith(suffix):
-                result = self.mapping.get(key[:-len(suffix)])
-                if result is not None:
-                    return (result[0], result[1] + suffix)
-        return None
-
-    def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
-        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
-        if result is None:
-            return None
-        return result[1]
-
-    def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
-        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
-        if result is None:
-            return None
-        return result[0]
-
-    def __getitem__(self, key: str) -> str:
-        try:
-            return self.mapping[key][1]
-        except KeyError:
-            raise KeyError(key)
-
-    def __contains__(self, key: str) -> bool:
-        return key in self.mapping
-
-    def __repr__(self) -> str:
-        return repr(self.mapping)
-
-def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
-    return TensorNameMap(arch, n_blocks)
-
-class TokenType(IntEnum):
-    NORMAL       = 1
-    UNKNOWN      = 2
-    CONTROL      = 3
-    USER_DEFINED = 4
-    UNUSED       = 5
-    BYTE         = 6
-
-class RopeScalingType(Enum):
-    NONE   = 'none'
-    LINEAR = 'linear'
-    YARN   = 'yarn'
-
-#
-# implementation
-#
-
-
-class GGMLQuantizationType(IntEnum):
-    F32  = 0
-    F16  = 1
-    Q4_0 = 2
-    Q4_1 = 3
-    Q5_0 = 6
-    Q5_1 = 7
-    Q8_0 = 8
-    Q8_1 = 9
-    Q2_K = 10
-    Q3_K = 11
-    Q4_K = 12
-    Q5_K = 13
-    Q6_K = 14
-    Q8_K = 15
-
-class GGUFEndian(IntEnum):
-    LITTLE = 0
-    BIG = 1
-
-
-class GGUFValueType(IntEnum):
-    UINT8   = 0
-    INT8    = 1
-    UINT16  = 2
-    INT16   = 3
-    UINT32  = 4
-    INT32   = 5
-    FLOAT32 = 6
-    BOOL    = 7
-    STRING  = 8
-    ARRAY   = 9
-    UINT64  = 10
-    INT64   = 11
-    FLOAT64 = 12
-
-    @staticmethod
-    def get_type(val):
-        if isinstance(val, str) or isinstance(val, bytes) or isinstance(val, bytearray):
-            return GGUFValueType.STRING
-        elif isinstance(val, list):
-            return GGUFValueType.ARRAY
-        elif isinstance(val, float):
-            return GGUFValueType.FLOAT32
-        elif isinstance(val, bool):
-            return GGUFValueType.BOOL
-        elif isinstance(val, int):
-            return GGUFValueType.INT32
-        # TODO: need help with 64-bit types in Python
-        else:
-            print("Unknown type: "+str(type(val)))
-            sys.exit()
-
-
-class WriterState(Enum):
-    EMPTY   = auto()
-    HEADER  = auto()
-    KV_DATA = auto()
-    TI_DATA = auto()
-
-
-class GGUFWriter:
-    fout: BufferedWriter
-    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any]]
-
-    @property
-    def pack_prefix(self):
-        if self.endianess==GGUFEndian.LITTLE:
-            return "<"
-        else:
-            return ">"
-
-    def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True, endianess=GGUFEndian.LITTLE):
-        self.fout = open(path, "wb")
-        self.arch = arch
-        self.endianess = endianess
-        self._simple_value_packing = {
-            GGUFValueType.UINT8:   f"{self.pack_prefix}B",
-            GGUFValueType.INT8:    f"{self.pack_prefix}b",
-            GGUFValueType.UINT16:  f"{self.pack_prefix}H",
-            GGUFValueType.INT16:   f"{self.pack_prefix}h",
-            GGUFValueType.UINT32:  f"{self.pack_prefix}I",
-            GGUFValueType.INT32:   f"{self.pack_prefix}i",
-            GGUFValueType.FLOAT32: f"{self.pack_prefix}f",
-            GGUFValueType.UINT64:  f"{self.pack_prefix}Q",
-            GGUFValueType.INT64:   f"{self.pack_prefix}q",
-            GGUFValueType.FLOAT64: f"{self.pack_prefix}d",
-            GGUFValueType.BOOL:    "?" ,
-        }
-        self.offset_tensor = 0
-        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
-        self.kv_data = b""
-        self.kv_data_count = 0
-        self.ti_data = b""
-        self.ti_data_count = 0
-        self.use_temp_file = use_temp_file
-        self.temp_file = None
-        self.tensors = []
-        endianess_str = "Big Endian" if self.endianess == GGUFEndian.BIG else "Little Endian"
-        print(f"This gguf file is for {endianess_str} only")
-        self.state = WriterState.EMPTY
-
-        self.add_architecture()
-
-    def write_header_to_file(self):
-        if self.state is not WriterState.EMPTY:
-            raise ValueError(f'Expected output file to be empty, got {self.state}')
-
-        self.fout.write(struct.pack("<I", GGUF_MAGIC))
-        self.fout.write(struct.pack(f"{self.pack_prefix}I", GGUF_VERSION))
-        self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.ti_data_count))
-        self.fout.write(struct.pack(f"{self.pack_prefix}Q", self.kv_data_count))
-        self.flush()
-        self.state = WriterState.HEADER
-
-    def write_kv_data_to_file(self):
-        if self.state is not WriterState.HEADER:
-            raise ValueError(f'Expected output file to contain the header, got {self.state}')
-
-        self.fout.write(self.kv_data)
-        self.flush()
-        self.state = WriterState.KV_DATA
-
-    def write_ti_data_to_file(self):
-        if self.state is not WriterState.KV_DATA:
-            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
-
-        self.fout.write(self.ti_data)
-        self.flush()
-        self.state = WriterState.TI_DATA
-
-    def add_key(self, key: str):
-        self.add_val(key, GGUFValueType.STRING, add_vtype=False)
-
-    def add_uint8(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.UINT8)
-
-    def add_int8(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.INT8)
-
-    def add_uint16(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.UINT16)
-
-    def add_int16(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.INT16)
-
-    def add_uint32(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.UINT32)
-
-    def add_int32(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.INT32)
-
-    def add_float32(self, key: str, val: float):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.FLOAT32)
-
-    def add_uint64(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.UINT64)
-
-    def add_int64(self, key: str, val: int):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.INT64)
-
-    def add_float64(self, key: str, val: float):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.FLOAT64)
-
-    def add_bool(self, key: str, val: bool):
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.BOOL)
-
-    def add_string(self, key: str, val: str):
-        if len(val) == 0:
-            return
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.STRING)
-
-    def add_array(self, key: str, val: Sequence[Any]):
-        if not isinstance(val, Sequence):
-            raise ValueError("Value must be a sequence for array type")
-
-        self.add_key(key)
-        self.add_val(val, GGUFValueType.ARRAY)
-
-    def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True):
-        if vtype is None:
-            vtype = GGUFValueType.get_type(val)
-
-        if add_vtype:
-            self.kv_data += struct.pack(f"{self.pack_prefix}I", vtype)
-            self.kv_data_count += 1
-
-        pack_fmt = self._simple_value_packing.get(vtype)
-        if pack_fmt is not None:
-            self.kv_data += struct.pack(pack_fmt, val)
-        elif vtype == GGUFValueType.STRING:
-            encoded_val = val.encode("utf8") if isinstance(val, str) else val
-            self.kv_data += struct.pack(f"{self.pack_prefix}Q", len(encoded_val))
-            self.kv_data += encoded_val
-        elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
-            ltype = GGUFValueType.get_type(val[0])
-            if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
-                raise ValueError("All items in a GGUF array should be of the same type")
-            self.kv_data += struct.pack(f"{self.pack_prefix}I", ltype)
-            self.kv_data += struct.pack(f"{self.pack_prefix}Q", len(val))
-            for item in val:
-                self.add_val(item, add_vtype=False)
-        else:
-            raise ValueError("Invalid GGUF metadata value type or value")
-
-    @staticmethod
-    def ggml_pad(x: int, n: int) -> int:
-        return ((x + n - 1) // n) * n
-
-    def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None):
-        if self.state is not WriterState.EMPTY:
-            raise ValueError(f'Expected output file to be empty, got {self.state}')
-
-        assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
-
-        encoded_name = name.encode("utf8")
-        self.ti_data += struct.pack(f"{self.pack_prefix}Q", len(encoded_name))
-        self.ti_data += encoded_name
-        n_dims = len(tensor_shape)
-        self.ti_data += struct.pack(f"{self.pack_prefix}I", n_dims)
-        for i in range(n_dims):
-            self.ti_data += struct.pack(f"{self.pack_prefix}Q", tensor_shape[n_dims - 1 - i])
-        if raw_dtype is None:
-            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
-        else:
-            dtype = raw_dtype
-        self.ti_data += struct.pack(f"{self.pack_prefix}I", dtype)
-        self.ti_data += struct.pack(f"{self.pack_prefix}Q", self.offset_tensor)
-        self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
-        self.ti_data_count += 1
-
-    def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None):
-        if self.endianess == GGUFEndian.BIG:
-            tensor.byteswap(inplace=True)
-        if self.use_temp_file and self.temp_file is None:
-            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
-            fp.seek(0)
-            self.temp_file = fp
-
-        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
-        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
-
-        if self.temp_file is None:
-            self.tensors.append(tensor)
-            return
-
-        tensor.tofile(self.temp_file)
-        self.write_padding(self.temp_file, tensor.nbytes)
-
-    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None):
-        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
-        if pad != 0:
-            fp.write(bytes([0] * pad))
-
-    def write_tensor_data(self, tensor: np.ndarray[Any, Any]):
-        if self.state is not WriterState.TI_DATA:
-            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
-
-        if self.endianess==GGUFEndian.BIG:
-            tensor.byteswap(inplace=True)
-        self.write_padding(self.fout, self.fout.tell())
-        tensor.tofile(self.fout)
-        self.write_padding(self.fout, tensor.nbytes)
-
-    def write_tensors_to_file(self):
-        self.write_ti_data_to_file()
-
-        self.write_padding(self.fout, self.fout.tell())
-
-        if self.temp_file is None:
-            while True:
-                try:
-                    tensor = self.tensors.pop(0)
-                except IndexError:
-                    break
-                tensor.tofile(self.fout)
-                self.write_padding(self.fout, tensor.nbytes)
-            return
-
-        self.temp_file.seek(0)
-
-        shutil.copyfileobj(self.temp_file, self.fout)
-        self.flush()
-        self.temp_file.close()
-
-    def flush(self):
-        self.fout.flush()
-
-    def close(self):
-        self.fout.close()
-
-    def add_architecture(self):
-        self.add_string(KEY_GENERAL_ARCHITECTURE, self.arch)
-
-    def add_author(self, author: str):
-        self.add_string(KEY_GENERAL_AUTHOR, author)
-
-    def add_tensor_data_layout(self, layout: str):
-        self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
-
-    def add_url(self, url: str):
-        self.add_string(KEY_GENERAL_URL, url)
-
-    def add_description(self, description: str):
-        self.add_string(KEY_GENERAL_DESCRIPTION, description)
-
-    def add_source_url(self, url: str):
-        self.add_string(KEY_GENERAL_SOURCE_URL, url)
-
-    def add_source_hf_repo(self, repo: str):
-        self.add_string(KEY_GENERAL_SOURCE_HF_REPO, repo)
-
-    def add_file_type(self, ftype: int):
-        self.add_uint32(KEY_GENERAL_FILE_TYPE, ftype)
-
-    def add_name(self, name: str):
-        self.add_string(KEY_GENERAL_NAME, name)
-
-    def add_quantization_version(self, quantization_version: GGMLQuantizationType):
-        self.add_uint32(
-            KEY_GENERAL_QUANTIZATION_VERSION, quantization_version)
-
-    def add_custom_alignment(self, alignment: int):
-        self.data_alignment = alignment
-        self.add_uint32(KEY_GENERAL_ALIGNMENT, alignment)
-
-    def add_context_length(self, length: int):
-        self.add_uint32(
-            KEY_CONTEXT_LENGTH.format(arch=self.arch), length)
-
-    def add_embedding_length(self, length: int):
-        self.add_uint32(
-            KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
-
-    def add_block_count(self, length: int):
-        self.add_uint32(
-            KEY_BLOCK_COUNT.format(arch=self.arch), length)
-
-    def add_feed_forward_length(self, length: int):
-        self.add_uint32(
-            KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
-
-    def add_parallel_residual(self, use: bool):
-        self.add_bool(
-            KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
-
-    def add_head_count(self, count: int):
-        self.add_uint32(
-            KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
-
-    def add_head_count_kv(self, count: int):
-        self.add_uint32(
-            KEY_ATTENTION_HEAD_COUNT_KV.format(arch=self.arch), count)
-
-    def add_max_alibi_bias(self, bias: float):
-        self.add_float32(
-            KEY_ATTENTION_MAX_ALIBI_BIAS.format(arch=self.arch), bias)
-
-    def add_clamp_kqv(self, value: float):
-        self.add_float32(
-            KEY_ATTENTION_CLAMP_KQV.format(arch=self.arch), value)
-
-    def add_layer_norm_eps(self, value: float):
-        self.add_float32(
-            KEY_ATTENTION_LAYERNORM_EPS.format(arch=self.arch), value)
-
-    def add_layer_norm_rms_eps(self, value: float):
-        self.add_float32(
-            KEY_ATTENTION_LAYERNORM_RMS_EPS.format(arch=self.arch), value)
-
-    def add_rope_dimension_count(self, count: int):
-        self.add_uint32(
-            KEY_ROPE_DIMENSION_COUNT.format(arch=self.arch), count)
-
-    def add_rope_freq_base(self, value: float):
-        self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
-
-    def add_rope_scaling_type(self, value: RopeScalingType):
-        self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
-
-    def add_rope_scaling_factor(self, value: float):
-        self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
-
-    def add_rope_scaling_orig_ctx_len(self, value: int):
-        self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
-
-    def add_rope_scaling_finetuned(self, value: bool):
-        self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
-
-    def add_tokenizer_model(self, model: str):
-        self.add_string(KEY_TOKENIZER_MODEL, model)
-
-    def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
-        self.add_array(KEY_TOKENIZER_LIST, tokens)
-
-    def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
-        self.add_array(KEY_TOKENIZER_MERGES, merges)
-
-    def add_token_types(self, types: Sequence[TokenType] | Sequence[int]):
-        self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
-
-    def add_token_scores(self, scores: Sequence[float]):
-        self.add_array(KEY_TOKENIZER_SCORES, scores)
-
-    def add_bos_token_id(self, id: int):
-        self.add_uint32(KEY_TOKENIZER_BOS_ID, id)
-
-    def add_eos_token_id(self, id: int):
-        self.add_uint32(KEY_TOKENIZER_EOS_ID, id)
-
-    def add_unk_token_id(self, id: int):
-        self.add_uint32(KEY_TOKENIZER_UNK_ID, id)
-
-    def add_sep_token_id(self, id: int):
-        self.add_uint32(KEY_TOKENIZER_SEP_ID, id)
-
-    def add_pad_token_id(self, id: int):
-        self.add_uint32(KEY_TOKENIZER_PAD_ID, id)
-
-
-class SpecialVocab:
-    merges: list[str]
-    special_token_ids: dict[str, int]
-
-    def __init__(
-        self, path: str | os.PathLike[str], load_merges: bool = False,
-        special_token_types: tuple[str, ...] | None = None,
-        n_vocab: int | None = None,
-    ):
-        self.special_token_ids = {}
-        self.n_vocab = n_vocab
-        self.load_merges = load_merges
-        self.merges = []
-        if special_token_types is not None:
-            self.special_token_types = special_token_types
-        else:
-            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad')
-        self._load(Path(path))
-
-    def _load(self, path: Path) -> None:
-        if not self._try_load_from_tokenizer_json(path):
-            self._try_load_from_config_json(path)
-
-    def _set_special_token(self, typ: str, tid: Any):
-        if not isinstance(tid, int) or tid < 0:
-            return
-        if self.n_vocab is None or tid < self.n_vocab:
-            self.special_token_ids[typ] = tid
-            return
-        print(f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
-            file = sys.stderr)
-
-
-    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
-        tokenizer_file = path / 'tokenizer.json'
-        if not tokenizer_file.is_file():
-            return False
-        with open(tokenizer_file, encoding = 'utf-8') as f:
-            tokenizer = json.load(f)
-        if self.load_merges:
-            merges = tokenizer.get('model', {}).get('merges')
-            if isinstance(merges, list) and len(merges) > 0 and isinstance(merges[0], str):
-                self.merges = merges
-        tokenizer_config_file = path / 'tokenizer_config.json'
-        added_tokens = tokenizer.get('added_tokens')
-        if added_tokens is None or not tokenizer_config_file.is_file():
-            return True
-        with open(tokenizer_config_file, encoding = 'utf-8') as f:
-            tokenizer_config = json.load(f)
-        for typ in self.special_token_types:
-            entry = tokenizer_config.get(f'{typ}_token')
-            if isinstance(entry, str):
-                tc_content = entry
-            elif isinstance(entry, dict):
-                entry_content = entry.get('content')
-                if not isinstance(entry_content, str):
-                    continue
-                tc_content = entry_content
-            else:
-                continue
-            # We only need the first match here.
-            maybe_token_id = next((
-                atok.get('id') for atok in added_tokens
-                if atok.get('content') == tc_content), None)
-            self._set_special_token(typ, maybe_token_id)
-        return True
-
-    def _try_load_from_config_json(self, path: Path) -> bool:
-        config_file = path / 'config.json'
-        if not config_file.is_file():
-            return False
-        with open(config_file, encoding = 'utf-8') as f:
-            config = json.load(f)
-        for typ in self.special_token_types:
-            self._set_special_token(typ, config.get(f'{typ}_token_id'))
-        return True
-
-    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
-        if len(self.merges) > 0:
-            if not quiet:
-                print(f'gguf: Adding {len(self.merges)} merge(s).')
-            gw.add_token_merges(self.merges)
-        for typ, tokid in self.special_token_ids.items():
-            handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
-            if handler is None:
-                print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping', file = sys.stderr)
-                continue
-            if not quiet:
-                print(f'gguf: Setting special token type {typ} to {tokid}')
-            handler(tokid)
-
-    def __repr__(self) -> str:
-        return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
-
-
-# Example usage:
-if __name__ == "__main__":
-    # Example usage with a file
-    gguf_writer = GGUFWriter("example.gguf", "llama")
-
-    gguf_writer.add_architecture()
-    gguf_writer.add_block_count(12)
-    gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer
-    gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float
-    gguf_writer.add_custom_alignment(64)
-
-    tensor1 = np.ones((32,), dtype=np.float32) * 100.0
-    tensor2 = np.ones((64,), dtype=np.float32) * 101.0
-    tensor3 = np.ones((96,), dtype=np.float32) * 102.0
-
-    gguf_writer.add_tensor("tensor1", tensor1)
-    gguf_writer.add_tensor("tensor2", tensor2)
-    gguf_writer.add_tensor("tensor3", tensor3)
-
-    gguf_writer.write_header_to_file()
-    gguf_writer.write_kv_data_to_file()
-    gguf_writer.write_tensors_to_file()
-
-    gguf_writer.close()
+importlib.reload(gguf)
diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py
new file mode 100644 (file)
index 0000000..8682765
--- /dev/null
@@ -0,0 +1,264 @@
+#
+# GGUF file reading/modification support. For API usage information,
+# please see the files scripts/ for some fairly simple examples.
+#
+from __future__ import annotations
+
+import os
+from collections import OrderedDict
+from typing import Any, Literal, NamedTuple, TypeVar, Union
+
+import numpy as np
+import numpy.typing as npt
+
+if __name__ == "__main__":
+    import sys
+    from pathlib import Path
+
+    # Allow running file in package as a script.
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from gguf.constants import (
+    GGML_QUANT_SIZES,
+    GGUF_DEFAULT_ALIGNMENT,
+    GGUF_MAGIC,
+    GGUF_VERSION,
+    GGMLQuantizationType,
+    GGUFValueType,
+)
+
+
+READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
+
+
+class ReaderField(NamedTuple):
+    # Offset to start of this field.
+    offset: int
+
+    # Name of the field (not necessarily from file data).
+    name: str
+
+    # Data parts. Some types have multiple components, such as strings
+    # that consist of a length followed by the string data.
+    parts: list[npt.NDArray[Any]] = []
+
+    # Indexes into parts that we can call the actual data. For example
+    # an array of strings will be populated with indexes to the actual
+    # string data.
+    data: list[int] = [-1]
+
+    types: list[GGUFValueType] = []
+
+
+class ReaderTensor(NamedTuple):
+    name: str
+    tensor_type: GGMLQuantizationType
+    shape: npt.NDArray[np.uint32]
+    n_elements: int
+    n_bytes: int
+    data_offset: int
+    data: npt.NDArray[Any]
+    field: ReaderField
+
+
+class GGUFReader:
+    # I - same as host, S - swapped
+    byte_order: Literal['I' | 'S'] = 'I'
+    alignment: int = GGUF_DEFAULT_ALIGNMENT
+
+    # Note: Internal helper, API may change.
+    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
+        GGUFValueType.UINT8:   np.uint8,
+        GGUFValueType.INT8:    np.int8,
+        GGUFValueType.UINT16:  np.uint16,
+        GGUFValueType.INT16:   np.int16,
+        GGUFValueType.UINT32:  np.uint32,
+        GGUFValueType.INT32:   np.int32,
+        GGUFValueType.FLOAT32: np.float32,
+        GGUFValueType.UINT64:  np.uint64,
+        GGUFValueType.INT64:   np.int64,
+        GGUFValueType.FLOAT64: np.float64,
+        GGUFValueType.BOOL:    np.bool_,
+    }
+
+    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'):
+        self.data = np.memmap(path, mode = mode)
+        offs = 0
+        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
+            raise ValueError('GGUF magic invalid')
+        offs += 4
+        temp_version = self._get(offs, np.uint32)
+        if temp_version[0] & 65535 == 0:
+            # If we get 0 here that means it's (probably) a GGUF file created for
+            # the opposite byte order of the machine this script is running on.
+            self.byte_order = 'S'
+            temp_version = temp_version.newbyteorder(self.byte_order)
+        version = temp_version[0]
+        if version not in READER_SUPPORTED_VERSIONS:
+            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
+        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
+        self.tensors: list[ReaderTensor] = []
+        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
+        temp_counts = self._get(offs, np.uint64, 2)
+        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
+        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
+        tensor_count, kv_count = temp_counts
+        offs = self._build_fields(offs, kv_count)
+        offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
+        new_align = self.fields.get('general.alignment')
+        if new_align is not None:
+            if new_align.types != [GGUFValueType.UINT64]:
+                raise ValueError('Bad type for general.alignment field')
+            self.alignment = new_align.parts[-1][0]
+        padding = offs % self.alignment
+        if padding != 0:
+            offs += self.alignment - padding
+        self._build_tensors(offs, tensors_fields)
+
+    _DT = TypeVar('_DT', bound = npt.DTypeLike)
+
+    # Fetch a key/value metadata field by key.
+    def get_field(self, key: str) -> Union[ReaderField, None]:
+        return self.fields.get(key, None)
+
+    # Fetch a tensor from the list by index.
+    def get_tensor(self, idx: int) -> ReaderTensor:
+        return self.tensors[idx]
+
+    def _get(
+        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None,
+    ) -> npt.NDArray[Any]:
+        count = int(count)
+        itemsize = int(np.empty([], dtype = dtype).itemsize)
+        end_offs = offset + itemsize * count
+        return (
+            self.data[offset:end_offs]
+            .view(dtype = dtype)[:count]
+            .newbyteorder(override_order or self.byte_order)
+        )
+
+    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
+        if field.name in self.fields:
+            raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
+        self.fields[field.name] = field
+        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
+
+    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
+        slen = self._get(offset, np.uint64)
+        return slen, self._get(offset + 8, np.uint8, slen[0])
+
+    def _get_field_parts(
+        self, orig_offs: int, raw_type: int,
+    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
+        offs = orig_offs
+        types: list[GGUFValueType] = []
+        gtype = GGUFValueType(raw_type)
+        types.append(gtype)
+        # Handle strings.
+        if gtype == GGUFValueType.STRING:
+            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
+            size = sum(int(part.nbytes) for part in sparts)
+            return size, sparts, [1], types
+        # Check if it's a simple scalar type.
+        nptype = self.gguf_scalar_to_np.get(gtype)
+        if nptype is not None:
+            val = self._get(offs, nptype)
+            return int(val.nbytes), [val], [0], types
+        # Handle arrays.
+        if gtype == GGUFValueType.ARRAY:
+            raw_itype = self._get(offs, np.uint32)
+            offs += int(raw_itype.nbytes)
+            alen = self._get(offs, np.uint64)
+            offs += int(alen.nbytes)
+            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
+            data_idxs: list[int] = []
+            for idx in range(alen[0]):
+                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
+                if idx == 0:
+                    types += curr_types
+                idxs_offs = len(aparts)
+                aparts += curr_parts
+                data_idxs += (idx + idxs_offs for idx in curr_idxs)
+                offs += curr_size
+            return offs - orig_offs, aparts, data_idxs, types
+        # We can't deal with this one.
+        raise ValueError('Unknown/unhandled field type {gtype}')
+
+    def _get_tensor(self, orig_offs: int) -> ReaderField:
+        offs = orig_offs
+        name_len, name_data = self._get_str(offs)
+        offs += int(name_len.nbytes + name_data.nbytes)
+        n_dims = self._get(offs, np.uint32)
+        offs += int(n_dims.nbytes)
+        dims = self._get(offs, np.uint64, n_dims[0])
+        offs += int(dims.nbytes)
+        raw_dtype = self._get(offs, np.uint32)
+        offs += int(raw_dtype.nbytes)
+        offset_tensor = self._get(offs, np.uint64)
+        offs += int(offset_tensor.nbytes)
+        return ReaderField(
+            orig_offs,
+            str(bytes(name_data), encoding = 'utf-8'),
+            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
+            [1, 3, 4, 5],
+        )
+
+    def _build_fields(self, offs: int, count: int) -> int:
+        for _ in range(count):
+            orig_offs = offs
+            kv_klen, kv_kdata = self._get_str(offs)
+            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
+            raw_kv_type = self._get(offs, np.uint32)
+            offs += int(raw_kv_type.nbytes)
+            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
+            idxs_offs = len(parts)
+            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
+            parts += field_parts
+            self._push_field(ReaderField(
+                orig_offs,
+                str(bytes(kv_kdata), encoding = 'utf-8'),
+                parts,
+                [idx + idxs_offs for idx in field_idxs],
+                field_types,
+            ), skip_sum = True)
+            offs += field_size
+        return offs
+
+    def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
+        tensor_fields = []
+        for _ in range(count):
+            field = self._get_tensor(offs)
+            offs += sum(int(part.nbytes) for part in field.parts)
+            tensor_fields.append(field)
+        return offs, tensor_fields
+
+    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
+        tensors = []
+        for field in fields:
+            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
+            ggml_type = GGMLQuantizationType(raw_dtype[0])
+            n_elems = np.prod(dims)
+            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
+            n_bytes = n_elems * type_size // block_size
+            data_offs = int(start_offs + offset_tensor[0])
+            item_type: npt.DTypeLike
+            if ggml_type == GGMLQuantizationType.F32:
+                item_count = n_elems
+                item_type = np.float32
+            elif ggml_type == GGMLQuantizationType.F16:
+                item_count = n_elems
+                item_type = np.float16
+            else:
+                item_count = n_bytes
+                item_type = np.uint8
+            tensors.append(ReaderTensor(
+                name = str(bytes(name_data), encoding = 'utf-8'),
+                tensor_type = ggml_type,
+                shape = dims,
+                n_elements = n_elems,
+                n_bytes = n_bytes,
+                data_offset = data_offs,
+                data = self._get(data_offs, item_type, item_count),
+                field = field,
+            ))
+        self.tensors = tensors
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
new file mode 100644 (file)
index 0000000..75fb697
--- /dev/null
@@ -0,0 +1,409 @@
+from __future__ import annotations
+
+import os
+import shutil
+import struct
+import tempfile
+from enum import Enum, auto
+from io import BufferedWriter
+from typing import IO, Any, Sequence
+
+import numpy as np
+
+from .constants import (
+    GGUF_DEFAULT_ALIGNMENT,
+    GGUF_MAGIC,
+    GGUF_VERSION,
+    GGMLQuantizationType,
+    GGUFEndian,
+    GGUFValueType,
+    Keys,
+    RopeScalingType,
+    TokenType,
+)
+
+
+class WriterState(Enum):
+    EMPTY   = auto()
+    HEADER  = auto()
+    KV_DATA = auto()
+    TI_DATA = auto()
+
+
+class GGUFWriter:
+    fout: BufferedWriter
+    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
+    tensors: list[np.ndarray[Any, Any]]
+    _simple_value_packing = {
+        GGUFValueType.UINT8:   "B",
+        GGUFValueType.INT8:    "b",
+        GGUFValueType.UINT16:  "H",
+        GGUFValueType.INT16:   "h",
+        GGUFValueType.UINT32:  "I",
+        GGUFValueType.INT32:   "i",
+        GGUFValueType.FLOAT32: "f",
+        GGUFValueType.UINT64:  "Q",
+        GGUFValueType.INT64:   "q",
+        GGUFValueType.FLOAT64: "d",
+        GGUFValueType.BOOL:    "?",
+    }
+
+    def __init__(
+        self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
+        endianess: GGUFEndian = GGUFEndian.LITTLE,
+    ):
+        self.fout = open(path, "wb")
+        self.arch = arch
+        self.endianess = endianess
+        self.offset_tensor = 0
+        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
+        self.kv_data = b""
+        self.kv_data_count = 0
+        self.ti_data = b""
+        self.ti_data_count = 0
+        self.use_temp_file = use_temp_file
+        self.temp_file = None
+        self.tensors = []
+        print("gguf: This GGUF file is for {0} Endian only".format(
+            "Big" if self.endianess == GGUFEndian.BIG else "Little",
+        ))
+        self.state = WriterState.EMPTY
+
+        self.add_architecture()
+
+    def write_header_to_file(self) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
+        self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
+        self._write_packed("I", GGUF_VERSION)
+        self._write_packed("Q", self.ti_data_count)
+        self._write_packed("Q", self.kv_data_count)
+        self.flush()
+        self.state = WriterState.HEADER
+
+    def write_kv_data_to_file(self) -> None:
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected output file to contain the header, got {self.state}')
+
+        self.fout.write(self.kv_data)
+        self.flush()
+        self.state = WriterState.KV_DATA
+
+    def write_ti_data_to_file(self) -> None:
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
+
+        self.fout.write(self.ti_data)
+        self.flush()
+        self.state = WriterState.TI_DATA
+
+    def add_key(self, key: str) -> None:
+        self.add_val(key, GGUFValueType.STRING, add_vtype=False)
+
+    def add_uint8(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.UINT8)
+
+    def add_int8(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.INT8)
+
+    def add_uint16(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.UINT16)
+
+    def add_int16(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.INT16)
+
+    def add_uint32(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.UINT32)
+
+    def add_int32(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.INT32)
+
+    def add_float32(self, key: str, val: float) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.FLOAT32)
+
+    def add_uint64(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.UINT64)
+
+    def add_int64(self, key: str, val: int) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.INT64)
+
+    def add_float64(self, key: str, val: float) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.FLOAT64)
+
+    def add_bool(self, key: str, val: bool) -> None:
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.BOOL)
+
+    def add_string(self, key: str, val: str) -> None:
+        if not val:
+            return
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.STRING)
+
+    def add_array(self, key: str, val: Sequence[Any]) -> None:
+        if not isinstance(val, Sequence):
+            raise ValueError("Value must be a sequence for array type")
+
+        self.add_key(key)
+        self.add_val(val, GGUFValueType.ARRAY)
+
+    def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
+        if vtype is None:
+            vtype = GGUFValueType.get_type(val)
+
+        if add_vtype:
+            self.kv_data += self._pack("I", vtype)
+            self.kv_data_count += 1
+
+        pack_fmt = self._simple_value_packing.get(vtype)
+        if pack_fmt is not None:
+            self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL)
+        elif vtype == GGUFValueType.STRING:
+            encoded_val = val.encode("utf8") if isinstance(val, str) else val
+            self.kv_data += self._pack("Q", len(encoded_val))
+            self.kv_data += encoded_val
+        elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val:
+            ltype = GGUFValueType.get_type(val[0])
+            if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
+                raise ValueError("All items in a GGUF array should be of the same type")
+            self.kv_data += self._pack("I", ltype)
+            self.kv_data += self._pack("Q", len(val))
+            for item in val:
+                self.add_val(item, add_vtype=False)
+        else:
+            raise ValueError("Invalid GGUF metadata value type or value")
+
+    @staticmethod
+    def ggml_pad(x: int, n: int) -> int:
+        return ((x + n - 1) // n) * n
+
+    def add_tensor_info(
+        self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32],
+        tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
+    ) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
+        if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
+            raise ValueError("Only F32 and F16 tensors are supported for now")
+
+        encoded_name = name.encode("utf8")
+        self.ti_data += self._pack("Q", len(encoded_name))
+        self.ti_data += encoded_name
+        n_dims = len(tensor_shape)
+        self.ti_data += self._pack("I", n_dims)
+        for i in range(n_dims):
+            self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
+        if raw_dtype is None:
+            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
+        else:
+            dtype = raw_dtype
+        self.ti_data += self._pack("I", dtype)
+        self.ti_data += self._pack("Q", self.offset_tensor)
+        self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
+        self.ti_data_count += 1
+
+    def add_tensor(
+        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        raw_dtype: GGMLQuantizationType | None = None,
+    ) -> None:
+        if self.endianess == GGUFEndian.BIG:
+            tensor.byteswap(inplace=True)
+        if self.use_temp_file and self.temp_file is None:
+            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
+            fp.seek(0)
+            self.temp_file = fp
+
+        shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
+        self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
+
+        if self.temp_file is None:
+            self.tensors.append(tensor)
+            return
+
+        tensor.tofile(self.temp_file)
+        self.write_padding(self.temp_file, tensor.nbytes)
+
+    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None:
+        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
+        if pad != 0:
+            fp.write(bytes([0] * pad))
+
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+        if self.state is not WriterState.TI_DATA:
+            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
+
+        if self.endianess == GGUFEndian.BIG:
+            tensor.byteswap(inplace=True)
+        self.write_padding(self.fout, self.fout.tell())
+        tensor.tofile(self.fout)
+        self.write_padding(self.fout, tensor.nbytes)
+
+    def write_tensors_to_file(self) -> None:
+        self.write_ti_data_to_file()
+
+        self.write_padding(self.fout, self.fout.tell())
+
+        if self.temp_file is None:
+            while True:
+                try:
+                    tensor = self.tensors.pop(0)
+                except IndexError:
+                    break
+                tensor.tofile(self.fout)
+                self.write_padding(self.fout, tensor.nbytes)
+            return
+
+        self.temp_file.seek(0)
+
+        shutil.copyfileobj(self.temp_file, self.fout)
+        self.flush()
+        self.temp_file.close()
+
+    def flush(self) -> None:
+        self.fout.flush()
+
+    def close(self) -> None:
+        self.fout.close()
+
+    def add_architecture(self) -> None:
+        self.add_string(Keys.General.ARCHITECTURE, self.arch)
+
+    def add_author(self, author: str) -> None:
+        self.add_string(Keys.General.AUTHOR, author)
+
+    def add_tensor_data_layout(self, layout: str) -> None:
+        self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+
+    def add_url(self, url: str) -> None:
+        self.add_string(Keys.General.URL, url)
+
+    def add_description(self, description: str) -> None:
+        self.add_string(Keys.General.DESCRIPTION, description)
+
+    def add_source_url(self, url: str) -> None:
+        self.add_string(Keys.General.SOURCE_URL, url)
+
+    def add_source_hf_repo(self, repo: str) -> None:
+        self.add_string(Keys.General.SOURCE_HF_REPO, repo)
+
+    def add_file_type(self, ftype: int) -> None:
+        self.add_uint32(Keys.General.FILE_TYPE, ftype)
+
+    def add_name(self, name: str) -> None:
+        self.add_string(Keys.General.NAME, name)
+
+    def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
+        self.add_uint32(
+            Keys.General.QUANTIZATION_VERSION, quantization_version)
+
+    def add_custom_alignment(self, alignment: int) -> None:
+        self.data_alignment = alignment
+        self.add_uint32(Keys.General.ALIGNMENT, alignment)
+
+    def add_context_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length)
+
+    def add_embedding_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)
+
+    def add_block_count(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
+
+    def add_feed_forward_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
+    def add_parallel_residual(self, use: bool) -> None:
+        self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+
+    def add_head_count(self, count: int) -> None:
+        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+
+    def add_head_count_kv(self, count: int) -> None:
+        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+
+    def add_max_alibi_bias(self, bias: float) -> None:
+        self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
+
+    def add_clamp_kqv(self, value: float) -> None:
+        self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
+
+    def add_layer_norm_eps(self, value: float) -> None:
+        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
+
+    def add_layer_norm_rms_eps(self, value: float) -> None:
+        self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
+
+    def add_rope_dimension_count(self, count: int) -> None:
+        self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
+
+    def add_rope_freq_base(self, value: float) -> None:
+        self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
+
+    def add_rope_scaling_type(self, value: RopeScalingType) -> None:
+        self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)
+
+    def add_rope_scaling_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
+        self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+
+    def add_rope_scaling_finetuned(self, value: bool) -> None:
+        self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
+
+    def add_tokenizer_model(self, model: str) -> None:
+        self.add_string(Keys.Tokenizer.MODEL, model)
+
+    def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
+        self.add_array(Keys.Tokenizer.LIST, tokens)
+
+    def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
+        self.add_array(Keys.Tokenizer.MERGES, merges)
+
+    def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None:
+        self.add_array(Keys.Tokenizer.TOKEN_TYPE, types)
+
+    def add_token_scores(self, scores: Sequence[float]) -> None:
+        self.add_array(Keys.Tokenizer.SCORES, scores)
+
+    def add_bos_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.BOS_ID, id)
+
+    def add_eos_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.EOS_ID, id)
+
+    def add_unk_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.UNK_ID, id)
+
+    def add_sep_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.SEP_ID, id)
+
+    def add_pad_token_id(self, id: int) -> None:
+        self.add_uint32(Keys.Tokenizer.PAD_ID, id)
+
+    def add_add_bos_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_BOS, value)
+
+    def add_add_eos_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_EOS, value)
+
+    def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
+        pack_prefix = ''
+        if not skip_pack_prefix:
+            pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
+        return struct.pack(f'{pack_prefix}{fmt}', value)
+
+    def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
+        self.fout.write(self._pack(fmt, value, skip_pack_prefix))
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
new file mode 100644 (file)
index 0000000..22ad8b8
--- /dev/null
@@ -0,0 +1,257 @@
+from __future__ import annotations
+
+from typing import Sequence
+
+from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES
+
+
+class TensorNameMap:
+    mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
+        # Token embeddings
+        MODEL_TENSOR.TOKEN_EMBD: (
+            "gpt_neox.embed_in",                         # gptneox
+            "transformer.wte",                           # gpt2 gpt-j mpt refact
+            "transformer.word_embeddings",               # falcon
+            "word_embeddings",                           # bloom
+            "model.embed_tokens",                        # llama-hf
+            "tok_embeddings",                            # llama-pth
+            "embeddings.word_embeddings",                # bert
+            "language_model.embedding.word_embeddings",  # persimmon
+        ),
+
+        # Token type embeddings
+        MODEL_TENSOR.TOKEN_TYPES: (
+            "embeddings.token_type_embeddings",  # bert
+        ),
+
+        # Normalization of token embeddings
+        MODEL_TENSOR.TOKEN_EMBD_NORM: (
+            "word_embeddings_layernorm",  # bloom
+        ),
+
+        # Position embeddings
+        MODEL_TENSOR.POS_EMBD: (
+            "transformer.wpe",                 # gpt2
+            "embeddings.position_embeddings",  # bert
+        ),
+
+        # Output
+        MODEL_TENSOR.OUTPUT: (
+            "embed_out",                 # gptneox
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan
+            "output",                    # llama-pth bloom
+            "word_embeddings_for_head",  # persimmon
+        ),
+
+        # Output norm
+        MODEL_TENSOR.OUTPUT_NORM: (
+            "gpt_neox.final_layer_norm",               # gptneox
+            "transformer.ln_f",                        # gpt2 gpt-j falcon
+            "model.norm",                              # llama-hf baichuan
+            "norm",                                    # llama-pth
+            "embeddings.LayerNorm",                    # bert
+            "transformer.norm_f",                      # mpt
+            "ln_f",                                    # refact bloom
+            "language_model.encoder.final_layernorm",  # persimmon
+        ),
+
+        # Rope frequencies
+        MODEL_TENSOR.ROPE_FREQS: (
+            "rope.freqs",  # llama-pth
+        ),
+    }
+
+    block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
+        # Attention norm
+        MODEL_TENSOR.ATTN_NORM: (
+            "gpt_neox.layers.{bid}.input_layernorm",                # gptneox
+            "transformer.h.{bid}.ln_1",                             # gpt2 gpt-j refact
+            "transformer.blocks.{bid}.norm_1",                      # mpt
+            "transformer.h.{bid}.input_layernorm",                  # falcon7b
+            "h.{bid}.input_layernorm",                              # bloom
+            "transformer.h.{bid}.ln_mlp",                           # falcon40b
+            "model.layers.{bid}.input_layernorm",                   # llama-hf
+            "layers.{bid}.attention_norm",                          # llama-pth
+            "encoder.layer.{bid}.attention.output.LayerNorm",       # bert
+            "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
+            "model.layers.{bid}.ln1",                               # yi
+        ),
+
+        # Attention norm 2
+        MODEL_TENSOR.ATTN_NORM_2: (
+            "transformer.h.{bid}.ln_attn",  # falcon40b
+        ),
+
+        # Attention query-key-value
+        MODEL_TENSOR.ATTN_QKV: (
+            "gpt_neox.layers.{bid}.attention.query_key_value",                     # gptneox
+            "transformer.h.{bid}.attn.c_attn",                                     # gpt2
+            "transformer.blocks.{bid}.attn.Wqkv",                                  # mpt
+            "transformer.h.{bid}.self_attention.query_key_value",                  # falcon
+            "h.{bid}.self_attention.query_key_value",                              # bloom
+            "language_model.encoder.layers.{bid}.self_attention.query_key_value",  # persimmon
+        ),
+
+        # Attention query
+        MODEL_TENSOR.ATTN_Q: (
+            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
+            "layers.{bid}.attention.wq",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.query",  # bert
+            "transformer.h.{bid}.attn.q_proj",           # gpt-j
+        ),
+
+        # Attention key
+        MODEL_TENSOR.ATTN_K: (
+            "model.layers.{bid}.self_attn.k_proj",     # llama-hf
+            "layers.{bid}.attention.wk",               # llama-pth
+            "encoder.layer.{bid}.attention.self.key",  # bert
+            "transformer.h.{bid}.attn.k_proj",         # gpt-j
+        ),
+
+        # Attention value
+        MODEL_TENSOR.ATTN_V: (
+            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
+            "layers.{bid}.attention.wv",                 # llama-pth
+            "encoder.layer.{bid}.attention.self.value",  # bert
+            "transformer.h.{bid}.attn.v_proj",           # gpt-j
+        ),
+
+        # Attention output
+        MODEL_TENSOR.ATTN_OUT: (
+            "gpt_neox.layers.{bid}.attention.dense",                     # gptneox
+            "transformer.h.{bid}.attn.c_proj",                           # gpt2 refact
+            "transformer.blocks.{bid}.attn.out_proj",                    # mpt
+            "transformer.h.{bid}.self_attention.dense",                  # falcon
+            "h.{bid}.self_attention.dense",                              # bloom
+            "model.layers.{bid}.self_attn.o_proj",                       # llama-hf
+            "layers.{bid}.attention.wo",                                 # llama-pth
+            "encoder.layer.{bid}.attention.output.dense",                # bert
+            "transformer.h.{bid}.attn.out_proj",                         # gpt-j
+            "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
+        ),
+
+        # Rotary embeddings
+        MODEL_TENSOR.ATTN_ROT_EMBD: (
+            "model.layers.{bid}.self_attn.rotary_emb.inv_freq",   # llama-hf
+            "layers.{bid}.attention.inner_attention.rope.freqs",  # llama-pth
+        ),
+
+        # Feed-forward norm
+        MODEL_TENSOR.FFN_NORM: (
+            "gpt_neox.layers.{bid}.post_attention_layernorm",                # gptneox
+            "transformer.h.{bid}.ln_2",                                      # gpt2 refact
+            "h.{bid}.post_attention_layernorm",                              # bloom
+            "transformer.blocks.{bid}.norm_2",                               # mpt
+            "model.layers.{bid}.post_attention_layernorm",                   # llama-hf
+            "layers.{bid}.ffn_norm",                                         # llama-pth
+            "encoder.layer.{bid}.output.LayerNorm",                          # bert
+            "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
+            "model.layers.{bid}.ln2",                                        # yi
+        ),
+
+        # Feed-forward up
+        MODEL_TENSOR.FFN_UP: (
+            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",                # gptneox
+            "transformer.h.{bid}.mlp.c_fc",                           # gpt2
+            "transformer.blocks.{bid}.ffn.up_proj",                   # mpt
+            "transformer.h.{bid}.mlp.dense_h_to_4h",                  # falcon
+            "h.{bid}.mlp.dense_h_to_4h",                              # bloom
+            "model.layers.{bid}.mlp.up_proj",                         # llama-hf refact
+            "layers.{bid}.feed_forward.w3",                           # llama-pth
+            "encoder.layer.{bid}.intermediate.dense",                 # bert
+            "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
+            "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
+        ),
+
+        # Feed-forward gate
+        MODEL_TENSOR.FFN_GATE: (
+            "model.layers.{bid}.mlp.gate_proj",  # llama-hf refact
+            "layers.{bid}.feed_forward.w1",      # llama-pth
+        ),
+
+        # Feed-forward down
+        MODEL_TENSOR.FFN_DOWN: (
+            "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",                # gptneox
+            "transformer.h.{bid}.mlp.c_proj",                         # gpt2 refact
+            "transformer.blocks.{bid}.ffn.down_proj",                 # mpt
+            "transformer.h.{bid}.mlp.dense_4h_to_h",                  # falcon
+            "h.{bid}.mlp.dense_4h_to_h",                              # bloom
+            "model.layers.{bid}.mlp.down_proj",                       # llama-hf
+            "layers.{bid}.feed_forward.w2",                           # llama-pth
+            "encoder.layer.{bid}.output.dense",                       # bert
+            "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
+            "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
+        ),
+
+        MODEL_TENSOR.ATTN_Q_NORM: (
+            "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
+        ),
+
+        MODEL_TENSOR.ATTN_K_NORM: (
+            "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
+        ),
+
+        MODEL_TENSOR.ROPE_FREQS: (
+            "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",  # persimmon
+        ),
+    }
+
+    mapping: dict[str, tuple[MODEL_TENSOR, str]]
+
+    def __init__(self, arch: MODEL_ARCH, n_blocks: int):
+        self.mapping = {}
+        for tensor, keys in self.mappings_cfg.items():
+            if tensor not in MODEL_TENSORS[arch]:
+                continue
+            tensor_name = TENSOR_NAMES[tensor]
+            self.mapping[tensor_name] = (tensor, tensor_name)
+            for key in keys:
+                self.mapping[key] = (tensor, tensor_name)
+        for bid in range(n_blocks):
+            for tensor, keys in self.block_mappings_cfg.items():
+                if tensor not in MODEL_TENSORS[arch]:
+                    continue
+                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
+                self.mapping[tensor_name] = (tensor, tensor_name)
+                for key in keys:
+                    key = key.format(bid = bid)
+                    self.mapping[key] = (tensor, tensor_name)
+
+    def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
+        result = self.mapping.get(key)
+        if result is not None:
+            return result
+        for suffix in try_suffixes:
+            if key.endswith(suffix):
+                result = self.mapping.get(key[:-len(suffix)])
+                if result is not None:
+                    return result[0], result[1] + suffix
+        return None
+
+    def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
+        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
+        if result is None:
+            return None
+        return result[1]
+
+    def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
+        result = self.get_type_and_name(key, try_suffixes = try_suffixes)
+        if result is None:
+            return None
+        return result[0]
+
+    def __getitem__(self, key: str) -> str:
+        try:
+            return self.mapping[key][1]
+        except KeyError:
+            raise KeyError(key)
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.mapping
+
+    def __repr__(self) -> str:
+        return repr(self.mapping)
+
+
+def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
+    return TensorNameMap(arch, n_blocks)
diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py
new file mode 100644 (file)
index 0000000..71192a9
--- /dev/null
@@ -0,0 +1,164 @@
+from __future__ import annotations
+
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Any, Callable
+
+from .gguf_writer import GGUFWriter
+
+
+class SpecialVocab:
+    merges: list[str]
+    add_special_token: dict[str, bool]
+    special_token_ids: dict[str, int]
+
+    def __init__(
+        self, path: str | os.PathLike[str], load_merges: bool = False,
+        special_token_types: tuple[str, ...] | None = None,
+        n_vocab: int | None = None,
+    ):
+        self.special_token_ids = {}
+        self.add_special_token = {}
+        self.n_vocab = n_vocab
+        self.load_merges = load_merges
+        self.merges = []
+        if special_token_types is not None:
+            self.special_token_types = special_token_types
+        else:
+            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad')
+        self._load(Path(path))
+
+    def __repr__(self) -> str:
+        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
+            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
+        )
+
+    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
+        if self.merges:
+            if not quiet:
+                print(f'gguf: Adding {len(self.merges)} merge(s).')
+            gw.add_token_merges(self.merges)
+        elif self.load_merges:
+            print(
+                'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
+                file = sys.stderr,
+            )
+        for typ, tokid in self.special_token_ids.items():
+            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
+            if id_handler is None:
+                print(
+                    f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping',
+                    file = sys.stderr,
+                )
+                continue
+            if not quiet:
+                print(f'gguf: Setting special token type {typ} to {tokid}')
+            id_handler(tokid)
+        for typ, value in self.add_special_token.items():
+            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
+            if add_handler is None:
+                print(
+                    f'gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping',
+                    file = sys.stderr,
+                )
+                continue
+            if not quiet:
+                print(f'gguf: Setting add_{typ}_token to {value}')
+            add_handler(value)
+
+    def _load(self, path: Path) -> None:
+        self._try_load_from_tokenizer_json(path)
+        self._try_load_from_config_json(path)
+        if self.load_merges and not self.merges:
+            self._try_load_merges_txt(path)
+
+    def _try_load_merges_txt(self, path: Path) -> bool:
+        merges_file = path / 'merges.txt'
+        if not merges_file.is_file():
+            return False
+        with open(merges_file, 'r') as fp:
+            first_line = next(fp, '').strip()
+            if not first_line.startswith('#'):
+                fp.seek(0)
+                line_num = 0
+            else:
+                line_num = 1
+            merges = []
+            for line in fp:
+                line_num += 1
+                line = line.strip()
+                if not line:
+                    continue
+                parts = line.split(None, 3)
+                if len(parts) != 2:
+                    print(
+                        f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
+                        file = sys.stderr,
+                    )
+                    continue
+                merges.append(f'{parts[0]} {parts[1]}')
+        self.merges = merges
+        return True
+
+    def _set_special_token(self, typ: str, tid: Any) -> None:
+        if not isinstance(tid, int) or tid < 0:
+            return
+        if self.n_vocab is None or tid < self.n_vocab:
+            if typ in self.special_token_ids:
+                return
+            self.special_token_ids[typ] = tid
+            return
+        print(
+            f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
+            file = sys.stderr,
+        )
+
+    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer_file = path / 'tokenizer.json'
+        if not tokenizer_file.is_file():
+            return False
+        with open(tokenizer_file, encoding = 'utf-8') as f:
+            tokenizer = json.load(f)
+        if self.load_merges:
+            merges = tokenizer.get('model', {}).get('merges')
+            if isinstance(merges, list) and merges and isinstance(merges[0], str):
+                self.merges = merges
+        tokenizer_config_file = path / 'tokenizer_config.json'
+        added_tokens = tokenizer.get('added_tokens')
+        if added_tokens is None or not tokenizer_config_file.is_file():
+            return True
+        with open(tokenizer_config_file, encoding = 'utf-8') as f:
+            tokenizer_config = json.load(f)
+        for typ in self.special_token_types:
+            add_entry = tokenizer_config.get(f'add_{typ}_token')
+            if isinstance(add_entry, bool):
+                self.add_special_token[typ] = add_entry
+            entry = tokenizer_config.get(f'{typ}_token')
+            if isinstance(entry, str):
+                tc_content = entry
+            elif isinstance(entry, dict):
+                entry_content = entry.get('content')
+                if not isinstance(entry_content, str):
+                    continue
+                tc_content = entry_content
+            else:
+                continue
+            # We only need the first match here.
+            maybe_token_id = next(
+                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
+                None,
+            )
+            self._set_special_token(typ, maybe_token_id)
+        return True
+
+    def _try_load_from_config_json(self, path: Path) -> bool:
+        config_file = path / 'config.json'
+        if not config_file.is_file():
+            return False
+        with open(config_file, encoding = 'utf-8') as f:
+            config = json.load(f)
+        for typ in self.special_token_types:
+            self._set_special_token(typ, config.get(f'{typ}_token_id'))
+        return True
index c6cb2c37a0e0a1587a18fac5b8b4e0d117318306..624e1cda628e1adc8b261f66709366884570398f 100644 (file)
@@ -1,11 +1,12 @@
 [tool.poetry]
 name = "gguf"
-version = "0.4.6"
+version = "0.5.0"
 description = "Write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [
     {include = "gguf"},
     {include = "gguf/py.typed"},
+    {include = "scripts"},
 ]
 readme = "README.md"
 homepage = "https://ggml.ai"
@@ -27,3 +28,8 @@ pytest = "^5.2"
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.poetry.scripts]
+gguf-convert-endian = "scripts:gguf_convert_endian_entrypoint"
+gguf-dump = "scripts:gguf_dump_entrypoint"
+gguf-set-metadata = "scripts:gguf_set_metadata_entrypoint"
diff --git a/gguf-py/scripts/__init__.py b/gguf-py/scripts/__init__.py
new file mode 100644 (file)
index 0000000..77132db
--- /dev/null
@@ -0,0 +1,12 @@
+import os
+
+from importlib import import_module
+
+
+os.environ["NO_LOCAL_GGUF"] = "TRUE"
+
+gguf_convert_endian_entrypoint = import_module("scripts.gguf-convert-endian").main
+gguf_dump_entrypoint           = import_module("scripts.gguf-dump").main
+gguf_set_metadata_entrypoint   = import_module("scripts.gguf-set-metadata").main
+
+del import_module, os
diff --git a/gguf-py/scripts/gguf-convert-endian.py b/gguf-py/scripts/gguf-convert-endian.py
new file mode 100755 (executable)
index 0000000..b79d86e
--- /dev/null
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf
+
+
+def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None:
+    if np.uint32(1) == np.uint32(1).newbyteorder("<"):
+        # Host is little endian
+        host_endian = "little"
+        swapped_endian = "big"
+    else:
+        # Sorry PDP or other weird systems that don't use BE or LE.
+        host_endian = "big"
+        swapped_endian = "little"
+    if reader.byte_order == "S":
+        file_endian = swapped_endian
+    else:
+        file_endian = host_endian
+    if args.order == "native":
+        order = host_endian
+    print(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian")
+    if file_endian == order:
+        print(f"* File is already {order.upper()} endian. Nothing to do.")
+        sys.exit(0)
+    print("* Checking tensors for conversion compatibility")
+    for tensor in reader.tensors:
+        if tensor.tensor_type not in (
+            gguf.GGMLQuantizationType.F32,
+            gguf.GGMLQuantizationType.F16,
+            gguf.GGMLQuantizationType.Q8_0,
+        ):
+            raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}")
+    print(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}")
+    if args.dry_run:
+        return
+    print("\n*** Warning *** Warning *** Warning **")
+    print("* This conversion process may damage the file. Ensure you have a backup.")
+    if order != host_endian:
+        print("* Requested endian differs from host, you will not be able to load the model on this machine.")
+    print("* The file will be modified immediately, so if conversion fails or is interrupted")
+    print("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:")
+    response = input("YES, I am sure> ")
+    if response != "YES":
+        print("You didn't enter YES. Okay then, see ya!")
+        sys.exit(0)
+    print(f"\n* Converting fields ({len(reader.fields)})")
+    for idx, field in enumerate(reader.fields.values()):
+        print(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}")
+        for part in field.parts:
+            part.byteswap(inplace=True)
+    print(f"\n* Converting tensors ({len(reader.tensors)})")
+    for idx, tensor in enumerate(reader.tensors):
+        print(
+            f"  - {idx:4}: Converting tensor {repr(tensor.name)}, type={tensor.tensor_type.name}, "
+            f"elements={tensor.n_elements}... ",
+            end="",
+        )
+        tensor_type = tensor.tensor_type
+        for part in tensor.field.parts:
+            part.byteswap(inplace=True)
+        if tensor_type != gguf.GGMLQuantizationType.Q8_0:
+            tensor.data.byteswap(inplace=True)
+            print()
+            continue
+        # A Q8_0 block consists of a f16 delta followed by 32 int8 quants, so 34 bytes
+        block_size = 34
+        n_blocks = len(tensor.data) // block_size
+        for block_num in range(n_blocks):
+            block_offs = block_num * block_size
+            # I know I said f16, but it doesn't matter here - any simple 16 bit type works.
+            delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16)
+            delta.byteswap(inplace=True)
+            if block_num % 100000 == 0:
+                print(f"[{(n_blocks - block_num) // 1000}K]", end="")
+                sys.stdout.flush()
+        print()
+    print("* Completion")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Convert GGUF file byte order")
+    parser.add_argument(
+        "model", type=str,
+        help="GGUF format model filename",
+    )
+    parser.add_argument(
+        "order", type=str, choices=['big', 'little', 'native'],
+        help="Requested byte order",
+    )
+    parser.add_argument(
+        "--dry-run", action="store_true",
+        help="Don't actually change anything",
+    )
+    args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
+    print(f'* Loading: {args.model}')
+    reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+')
+    convert_byteorder(reader, args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/gguf-py/scripts/gguf-dump.py b/gguf-py/scripts/gguf-dump.py
new file mode 100755 (executable)
index 0000000..5141873
--- /dev/null
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from gguf import GGUFReader, GGUFValueType  # noqa: E402
+
+
+def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
+    host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG'
+    if reader.byte_order == 'S':
+        file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE'
+    else:
+        file_endian = host_endian
+    return (host_endian, file_endian)
+
+
+# For more information about what field.parts and field.data represent,
+# please see the comments in the modify_gguf.py example.
+def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
+    host_endian, file_endian = get_file_host_endian(reader)
+    print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.')
+    print(f'\n* Dumping {len(reader.fields)} key/value pair(s)')
+    for n, field in enumerate(reader.fields.values(), 1):
+        if not field.types:
+            pretty_type = 'N/A'
+        elif field.types[0] == GGUFValueType.ARRAY:
+            nest_count = len(field.types) - 1
+            pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count
+        else:
+            pretty_type = str(field.types[-1].name)
+        print(f'  {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}', end = '')
+        if len(field.types) == 1:
+            curr_type = field.types[0]
+            if curr_type == GGUFValueType.STRING:
+                print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])), end = '')
+            elif field.types[0] in reader.gguf_scalar_to_np:
+                print(' = {0}'.format(field.parts[-1][0]), end = '')
+        print()
+    if args.no_tensors:
+        return
+    print(f'\n* Dumping {len(reader.tensors)} tensor(s)')
+    for n, tensor in enumerate(reader.tensors, 1):
+        prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape)))
+        print(f'  {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}')
+
+
+def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None:
+    import json
+    host_endian, file_endian = get_file_host_endian(reader)
+    metadata: dict[str, Any] = {}
+    tensors: dict[str, Any] = {}
+    result = {
+        "filename": args.model,
+        "endian": file_endian,
+        "metadata": metadata,
+        "tensors": tensors,
+    }
+    for idx, field in enumerate(reader.fields.values()):
+        curr: dict[str, Any] = {
+            "index": idx,
+            "type": field.types[0].name if field.types else 'UNKNOWN',
+            "offset": field.offset,
+        }
+        metadata[field.name] = curr
+        if field.types[:1] == [GGUFValueType.ARRAY]:
+            curr["array_types"] = [t.name for t in field.types][1:]
+            if not args.json_array:
+                continue
+            itype = field.types[-1]
+            if itype == GGUFValueType.STRING:
+                curr["value"] = [str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data]
+            else:
+                curr["value"] = [pv for idx in field.data for pv in field.parts[idx].tolist()]
+        elif field.types[0] == GGUFValueType.STRING:
+            curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8")
+        else:
+            curr["value"] = field.parts[-1].tolist()[0]
+    for idx, tensor in enumerate(reader.tensors):
+        tensors[tensor.name] = {
+            "index": idx,
+            "shape": tensor.shape.tolist(),
+            "type": tensor.tensor_type.name,
+            "offset": tensor.field.offset,
+        }
+    json.dump(result, sys.stdout)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Dump GGUF file metadata")
+    parser.add_argument("model",           type=str,            help="GGUF format model filename")
+    parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata")
+    parser.add_argument("--json",       action="store_true", help="Produce JSON output")
+    parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)")
+    args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
+    if not args.json:
+        print(f'* Loading: {args.model}')
+    reader = GGUFReader(args.model, 'r')
+    if args.json:
+        dump_metadata_json(reader, args)
+    else:
+        dump_metadata(reader, args)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/gguf-py/scripts/gguf-set-metadata.py b/gguf-py/scripts/gguf-set-metadata.py
new file mode 100755 (executable)
index 0000000..3ebdfa8
--- /dev/null
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+import argparse
+import os
+import sys
+from pathlib import Path
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from gguf import GGUFReader  # noqa: E402
+
+
+def minimal_example(filename: str) -> None:
+    reader = GGUFReader(filename, 'r+')
+    field = reader.fields['tokenizer.ggml.bos_token_id']
+    if field is None:
+        return
+    part_index = field.data[0]
+    field.parts[part_index][0] = 2  # Set tokenizer.ggml.bos_token_id to 2
+    #
+    # So what's this field.data thing? It's helpful because field.parts contains
+    # _every_ part of the GGUF field. For example, tokenizer.ggml.bos_token_id consists
+    # of:
+    #
+    #  Part index 0: Key length (27)
+    #  Part index 1: Key data ("tokenizer.ggml.bos_token_id")
+    #  Part index 2: Field type (4, the id for GGUFValueType.UINT32)
+    #  Part index 3: Field value
+    #
+    # Note also that each part is an NDArray slice, so even a part that
+    # is only a single value like the key length will be a NDArray of
+    # the key length type (numpy.uint32).
+    #
+    # The .data attribute in the Field is a list of relevant part indexes
+    # and doesn't contain internal GGUF details like the key length part.
+    # In this case, .data will be [3] - just the part index of the
+    # field value itself.
+
+
+def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None:
+    field = reader.get_field(args.key)
+    if field is None:
+        print(f'! Field {repr(args.key)} not found', file = sys.stderr)
+        sys.exit(1)
+    # Note that field.types is a list of types. This is because the GGUF
+    # format supports arrays. For example, an array of UINT32 would
+    # look like [GGUFValueType.ARRAY, GGUFValueType.UINT32]
+    handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None
+    if handler is None:
+        print(
+            f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}',
+            file = sys.stderr,
+        )
+        sys.exit(1)
+    current_value = field.parts[field.data[0]][0]
+    new_value = handler(args.value)
+    print(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}')
+    if current_value == new_value:
+        print(f'- Key {repr(args.key)} already set to requested value {current_value}')
+        sys.exit(0)
+    if args.dry_run:
+        sys.exit(0)
+    if not args.force:
+        print('*** Warning *** Warning *** Warning **')
+        print('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.')
+        print('* Enter exactly YES if you are positive you want to proceed:')
+        response = input('YES, I am sure> ')
+        if response != 'YES':
+            print("You didn't enter YES. Okay then, see ya!")
+            sys.exit(0)
+    field.parts[field.data[0]][0] = new_value
+    print('* Field changed. Successful completion.')
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Set a simple value in GGUF file metadata")
+    parser.add_argument("model",     type=str,            help="GGUF format model filename")
+    parser.add_argument("key",       type=str,            help="Metadata key to set")
+    parser.add_argument("value",     type=str,            help="Metadata value to set")
+    parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything")
+    parser.add_argument("--force",   action="store_true", help="Change the field without confirmation")
+    args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
+    print(f'* Loading: {args.model}')
+    reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+')
+    set_metadata(reader, args)
+
+
+if __name__ == '__main__':
+    main()
index 512531dd2a8f0a836fa46af6edc0c4f0e0e48667..0adeb7d55731a4962beafcdab64d307131a9a3c8 100644 (file)
@@ -1,7 +1,7 @@
-import gguf
+import gguf  # noqa: F401
 
 # TODO: add tests
 
 
-def test_write_gguf():
+def test_write_gguf() -> None:
     pass