]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
mpt : convert model weights part by part to save memory (#314)
authorHugo Rosenkranz-Costa <redacted>
Sun, 2 Jul 2023 16:05:24 +0000 (18:05 +0200)
committerGitHub <redacted>
Sun, 2 Jul 2023 16:05:24 +0000 (19:05 +0300)
* mpt : update conversion script to load model weights part by part

* mpt : add usage README

examples/mpt/README.md [new file with mode: 0644]
examples/mpt/convert-h5-to-ggml.py [changed mode: 0644->0755]

diff --git a/examples/mpt/README.md b/examples/mpt/README.md
new file mode 100644 (file)
index 0000000..39f46ba
--- /dev/null
@@ -0,0 +1,27 @@
+# MPT
+
+Ref: https://github.com/mosaicml/llm-foundry#mpt
+
+## Usage
+
+```bash
+# get the repo and build it
+git clone https://github.com/ggerganov/ggml
+cd ggml
+mkdir build && cd build
+cmake ..
+make -j
+
+# get the model from HuggingFace
+# be sure to have git-lfs installed
+git clone https://huggingface.co/mosaicml/mpt-30b
+
+# convert model to FP16
+python3 ../examples/mpt/convert-h5-to-ggml.py ./mpt-30b 1
+
+# run inference using FP16 precision
+./bin/mpt -m ./mpt-30b/ggml-model-f16.bin -p "I believe the meaning of life is" -t 8 -n 64
+
+# quantize the model to 5-bits using Q5_0 quantization
+./bin/mpt-quantize ./mpt-30b/ggml-model-f16.bin ./mpt-30b/ggml-model-q5_0.bin q5_0
+```
old mode 100644 (file)
new mode 100755 (executable)
index 0765011..ccd6459
@@ -1,13 +1,13 @@
-import sys
+import os
 import struct
-import json
-import numpy as np
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import sentencepiece.sentencepiece_model_pb2 as model
+import sys
+
+import torch
+from transformers import AutoConfig, AutoTokenizer
+
 
 # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
 def bytes_to_unicode():
-
     """
     Returns list of utf-8 byte and a corresponding list of unicode strings.
     The reversible bpe codes work on unicode strings.
@@ -17,19 +17,36 @@ def bytes_to_unicode():
     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
     And avoids mapping to whitespace/control characters the bpe code barfs on.
     """
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
     cs = bs[:]
     n = 0
     for b in range(2**8):
         if b not in bs:
             bs.append(b)
-            cs.append(2**8+n)
+            cs.append(2**8 + n)
             n += 1
 
     cs = [chr(n) for n in cs]
 
     return dict(zip(bs, cs))
 
+
+def count_model_parts(dir_model: str) -> int:
+    """Returns the number of model parts in the model directory."""
+    num_parts = 0
+    for filename in os.listdir(dir_model):
+        if filename.startswith("pytorch_model-"):
+            num_parts += 1
+
+    if num_parts > 0:
+        print(f"Found {num_parts} model parts in {dir_model}")
+    return num_parts
+
+
 if len(sys.argv) < 3:
     print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
     print("  ftype == 0 -> float32")
@@ -39,11 +56,8 @@ if len(sys.argv) < 3:
 
 # output in the same directory as the model
 dir_model = sys.argv[1]
-fname_out = sys.argv[1] + "/ggml-model.bin"
-
-
-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
-    hparams = json.load(f)
+# get number of model parts
+num_parts = count_model_parts(dir_model)
 
 # possible data types
 #   ftype == 0 -> float32
@@ -58,25 +72,15 @@ if len(sys.argv) > 2:
     if ftype < 0 or ftype > 1:
         print("Invalid ftype: " + str(ftype))
         sys.exit(1)
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
+    fname_out = dir_model + "/ggml-model-" + ftype_str[ftype] + ".bin"
 
 
 tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
-    dir_model, low_cpu_mem_usage=True, trust_remote_code=True
-)
-# print (model)
-
-# print(tokenizer.encode('I believe the meaning of life is'))
-
-list_vars = model.state_dict()
-for name in list_vars.keys():
-    print(name, list_vars[name].shape, list_vars[name].dtype)
+config = AutoConfig.from_pretrained(dir_model, trust_remote_code=True)
+hparams = config.to_dict()
 
 fout = open(fname_out, "wb")
 
-print(hparams)
-
 fout.write(struct.pack("i", 0x67676D6C))  # magic: ggml in hex
 fout.write(struct.pack("i", hparams["d_model"]))
 fout.write(struct.pack("i", hparams["max_seq_len"]))
@@ -94,19 +98,19 @@ encoder = tokenizer.vocab
 encoder.update(tokenizer.get_added_vocab())
 
 byte_encoder = bytes_to_unicode()
-byte_decoder = {v:k for k, v in byte_encoder.items()}
+byte_decoder = {v: k for k, v in byte_encoder.items()}
 
 counter = 0
 # sort by value
 for key in sorted(encoder, key=encoder.get):
     # workaround for key error when c not found
-    text=""
+    text = ""
     for c in key:
         if c not in byte_decoder:
             text += c
         else:
-            text += chr(byte_decoder[c] )
-    text = bytearray( text, encoding="utf-8" )
+            text += chr(byte_decoder[c])
+    text = bytearray(text, encoding="utf-8")
     fout.write(struct.pack("i", len(text)))
     fout.write(text)
     counter += 1
@@ -117,40 +121,47 @@ while counter < vocab_size:
     fout.write(text)
     counter += 1
 
-# assert counter == config.vocab_size
-
-for name in list_vars.keys():
-    data = list_vars[name].squeeze().numpy()
-    print("Processing variable: " + name + " with shape: ", data.shape)
-
-    n_dims = len(data.shape)
-
-    # ftype == 0 -> float32, ftype == 1 -> float16
-    ftype_cur = 0
-    if ftype != 0:
-        if name[-7:] == ".weight" and n_dims == 2:
-            print("  Converting to float16")
-            data = data.astype(np.float16)
+if num_parts == 0:
+    part_names = ("pytorch_model.bin",)
+else:
+    part_names = (
+        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
+    )
+
+for part_name in part_names:
+    print(f"\n* Loading part: {part_name}")
+    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+
+    for name in model_part.keys():
+        data = model_part[name].squeeze()
+        n_dims = len(data.shape)
+
+        # ftype == 0 -> float32, ftype == 1 -> float16
+        # default type is fp32
+        ftype_cur = 0
+        if ftype == 1 and name[-7:] == ".weight" and n_dims > 1:
             ftype_cur = 1
-        else:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-    else:
-        if data.dtype != np.float32:
-            print("  Converting to float32")
-            data = data.astype(np.float32)
-            ftype_cur = 0
-
-    # header
-    str = name.encode("utf-8")
-    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
-    for i in range(n_dims):
-        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
-    fout.write(str)
-
-    # data
-    data.tofile(fout)
+        data = data.to(dtype=torch.float16 if ftype_cur == 1 else torch.float32).numpy()
+
+        print(
+            "Processing variable: " + name + " with shape: ",
+            data.shape,
+            "->",
+            data.dtype,
+        )
+
+        # header
+        str = name.encode("utf-8")
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
+
+    # release memory
+    del model_part
 
 fout.close()