]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
gpt-2 : add Cerebras-GPT example
authorGeorgi Gerganov <redacted>
Thu, 30 Mar 2023 20:39:15 +0000 (23:39 +0300)
committerGeorgi Gerganov <redacted>
Thu, 30 Mar 2023 20:48:33 +0000 (23:48 +0300)
README.md
examples/gpt-2/README.md
examples/gpt-2/convert-cerebras-to-ggml.py [new file with mode: 0644]
examples/gpt-2/convert-h5-to-ggml.py
examples/gpt-2/main.cpp

index 8ad8d84e3f0b4ae14cb6619d7f195fe11e96d807..5c33068f1efe406392696749803049ca4dd37743 100644 (file)
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ Some of the development is currently happening in the [llama.cpp](https://github
 - [X] Example of GPT-J inference [examples/gpt-j](https://github.com/ggerganov/ggml/tree/master/examples/gpt-j)
 - [X] Example of Whisper inference [examples/whisper](https://github.com/ggerganov/ggml/tree/master/examples/whisper)
 - [X] Support 4-bit integer quantization https://github.com/ggerganov/ggml/pull/27
+- [X] Example of Cerebras-GPT inference [examples/gpt-2](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2)
 - [ ] Example of FLAN-T5 inference https://github.com/ggerganov/ggml/pull/12
 - [X] Example of LLaMA inference [llama.cpp](https://github.com/ggerganov/llama.cpp)
 - [ ] Example of RWKV inference
@@ -62,6 +63,11 @@ make -j4 gpt-2 gpt-j
 # Run the GPT-J 6B model (requires 12GB disk space and 16GB CPU RAM)
 ../examples/gpt-j/download-ggml-model.sh 6B
 ./bin/gpt-j -m models/gpt-j-6B/ggml-model.bin -p "This is an example"
+
+# Run the Cerebras-GPT 111M model
+# Download from: https://huggingface.co/cerebras
+python3 ./examples/gpt-2/convert-cerebras-to-ggml.py /path/to/Cerebras-GPT-111M/
+./bin/gpt-2 -m /path/to/Cerebras-GPT-111M/ggml-model-f16.bin -p "This is an example"
 ```
 
 The inference speeds that I get for the different models on my 32GB MacBook M1 Pro are as follows:
index ed61f861f678750e4ab947a4ddb979ba44e467e9..568bf435871b8589b61d9ac70046885625badd97 100644 (file)
@@ -4,7 +4,9 @@ This is a C++ example running GPT-2 inference using the [ggml](https://github.co
 
 The program runs on the CPU - no video card is required.
 
-The example supports the following models:
+The [Cerebras-GPT](https://huggingface.co/cerebras) models are also supported.
+
+The example supports the following GPT-2 models:
 
 | Model | Description  | Disk Size |
 | ---   | ---          | ---       |
@@ -22,6 +24,8 @@ Sample performance on MacBook M1 Pro:
 | GPT-2 |  774M |  23 ms |
 | GPT-2 | 1558M |  42 ms |
 
+*TODO: add tables for Cerebras-GPT models*
+
 Sample output:
 
 ```
@@ -70,7 +74,7 @@ main:  predict time =   506.40 ms / 5.06 ms per token
 main:    total time =   629.84 ms
 ```
 
-## Downloading and converting the original models
+## Downloading and converting the original models (GPT-2)
 
 You can download the original model files using the [download-model.sh](download-model.sh) Bash script. The models are
 in Tensorflow format, so in order to use them with ggml, you need to convert them to appropriate format. This is done
@@ -101,6 +105,19 @@ Run the convert-ckpt-to-ggml.py script to convert the model to ggml format.
 This conversion requires that you have python and Tensorflow installed on your computer. Still, if you want to avoid
 this, you can download the already converted ggml models as described below.
 
+## Downloading and converting the original models (Cerebras-GPT)
+
+Clone the respective repository from here: https://huggingface.co/cerebras
+
+Use the [convert-cerebras-to-ggml.py](convert-cerebras-to-ggml.py) script to convert the model to `ggml` format:
+
+```
+cd ggml/build
+git clone https://huggingface.co/cerebras/Cerebras-GPT-111M models/
+python ../examples/gpt-2/convert-cerebras-to-ggml.py models/Cerebras-GPT-111M/
+
+```
+
 ## Downloading the ggml model directly
 
 For convenience, I will be hosting the converted ggml model files in order to make it easier to run the examples. This
@@ -122,3 +139,20 @@ You can now use it like this:
 ```
 
 At some point, I might decide to stop hosting these models. So in that case, simply revert to the manual process above.
+
+## Quantizing the models
+
+You can also try to quantize the `ggml` models via 4-bit integer quantization.
+Keep in mind that for smaller models, this will render them completely useless.
+You generally want to quantize larger models.
+
+```
+# quantize GPT-2 F16 to Q4_0 (faster but less precise)
+./bin/gpt-2-quantize models/gpt-2-1558M/ggml-model-f16.bin models/gpt-2-1558M/ggml-model-q4_0.bin 2
+./bin/gpt-2 -m models/gpt-2-1558M/ggml-model-q4_0.bin -p "This is an example"
+
+# quantize Cerebras F16 to Q4_1 (slower but more precise)
+./bin/gpt-2-quantize models/Cerebras-GPT-6.7B/ggml-model-f16.bin models/Cerebras-GPT-6.7B/ggml-model-q4_1.bin 3
+./bin/gpt-2 -m models/Cerebras-GPT-6.7B/ggml-model-q4_1.bin -p "This is an example"
+
+```
diff --git a/examples/gpt-2/convert-cerebras-to-ggml.py b/examples/gpt-2/convert-cerebras-to-ggml.py
new file mode 100644 (file)
index 0000000..6f20a54
--- /dev/null
@@ -0,0 +1,183 @@
+# Convert Cerebras models to ggml format
+#
+# ref: https://www.cerebras.net/blog/cerebras-gpt-a-family-of-open-compute-efficient-large-language-models/
+#
+
+import sys
+import struct
+import json
+import torch
+import numpy as np
+import re
+
+from transformers import GPTJForCausalLM, AutoModelForCausalLM
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+if len(sys.argv) < 2:
+    print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
+    sys.exit(1)
+
+# output in the same directory as the model
+dir_model = sys.argv[1]
+fname_out = sys.argv[1] + "/ggml-model-f16.bin"
+
+with open(dir_model + "/vocab.json", "r") as f:
+    encoder = json.load(f)
+
+with open(dir_model + "/config.json", "r") as f:
+    hparams = json.load(f)
+
+# use 16-bit or 32-bit floats
+use_f16 = True
+if len(sys.argv) > 2:
+    use_f16 = False
+    fname_out = sys.argv[1] + "/ggml-model-f32.bin"
+
+model = AutoModelForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)
+#print (model)
+
+list_vars = model.state_dict()
+#print (list_vars)
+
+print(hparams)
+
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", hparams["vocab_size"]))
+fout.write(struct.pack("i", hparams["n_positions"]))
+fout.write(struct.pack("i", hparams["n_embd"]))
+fout.write(struct.pack("i", hparams["n_head"]))
+fout.write(struct.pack("i", hparams["n_layer"]))
+fout.write(struct.pack("i", use_f16))
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+fout.write(struct.pack("i", len(encoder)))
+
+for key in encoder:
+    text = bytearray([byte_decoder[c] for c in key])
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+
+for name in list_vars.keys():
+    data = list_vars[name].squeeze().numpy()
+    print("Processing variable: " + name + " with shape: ", data.shape)
+
+    # rename headers to keep compatibility
+    if name == "transformer.ln_f.weight":
+        name = "model/ln_f/g"
+    elif name == "transformer.ln_f.bias":
+        name = "model/ln_f/b"
+    elif name == "transformer.wte.weight":
+        name = "model/wte"
+    elif name == "transformer.wpe.weight":
+        name = "model/wpe"
+    elif name == "lm_head.weight":
+        name = "model/lm_head"
+    elif re.match(r"transformer.h\.\d+\.ln_1\.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_1/g"
+    elif re.match(r"transformer.h\.\d+\.ln_1\.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_1/b"
+    elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_attn/w"
+    elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_attn/b"
+    elif re.match(r"transformer.h\.\d+\.attn\.c_proj\.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_proj/w"
+    elif re.match(r"transformer.h.\d+.attn.c_proj.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_proj/b"
+    elif re.match(r"transformer.h.\d+.ln_2.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_2/g"
+    elif re.match(r"transformer.h.\d+.ln_2.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_2/b"
+    elif re.match(r"transformer.h.\d+.mlp.c_fc.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_fc/w"
+    elif re.match(r"transformer.h.\d+.mlp.c_fc.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_fc/b"
+    elif re.match(r"transformer.h.\d+.mlp.c_proj.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_proj/w"
+    elif re.match(r"transformer.h.\d+.mlp.c_proj.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_proj/b"
+    else:
+        print("Unrecognized variable name. %s", name)
+
+    # we don't need these
+    if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
+        print("  Skipping variable: " + name)
+        continue
+
+    n_dims = len(data.shape);
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype = 0;
+    if use_f16:
+        if (name == "model/wte" or name == "model/lm_head" or name[-2:] == "/g" or name[-2:] == "/w") and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype = 0
+
+    # for efficiency - transpose the projection matrices
+    # "model/h.*/attn/c_attn/w"
+    # "model/h.*/attn/c_proj/w"
+    # "model/h.*/mlp/c_fc/w"
+    # "model/h.*/mlp/c_proj/w"
+    if name[-14:] == "/attn/c_attn/w" or \
+       name[-14:] == "/attn/c_proj/w" or \
+       name[-11:] == "/mlp/c_fc/w" or \
+       name[-13:] == "/mlp/c_proj/w":
+        print("  Transposing")
+        data = data.transpose()
+
+    # header
+    str = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(str);
+
+    # data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
index 7ee3395748189b1c24fa7b4610522b4a2f639c28..4e86ce27e65db60d903b6a4a8e45b400b1326d45 100644 (file)
@@ -140,7 +140,7 @@ for name in list_vars.keys():
         name = "model/wte"
     elif name == "wpe.weight":
         name = "model/wpe"
-    elif re.match(r'h\.\d+\.ln_1\.weight', name):
+    elif re.match(r"h\.\d+\.ln_1\.weight", name):
         i = re.findall("\d+", name)[0]
         name = f"model/h{i}/ln_1/g"
     elif re.match(r"h\.\d+\.ln_1\.bias", name):
@@ -192,4 +192,4 @@ for name in list_vars.keys():
 fout.close()
 
 print("Done. Output file: " + fname_out)
-print("")
\ No newline at end of file
+print("")
index f371bc66735d821e2b2e139c44f62d0e9275c37d..d9feb6e973e3d66da3f13b29fa4ad1b2f3740deb 100644 (file)
@@ -53,8 +53,9 @@ struct gpt2_model {
     struct ggml_tensor * ln_f_g;
     struct ggml_tensor * ln_f_b;
 
-    struct ggml_tensor * wte; // position embedding
-    struct ggml_tensor * wpe; //    token embedding
+    struct ggml_tensor * wte;     // position embedding
+    struct ggml_tensor * wpe;     //    token embedding
+    struct ggml_tensor * lm_head; // language model head
 
     std::vector<gpt2_layer> layers;
 
@@ -165,6 +166,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
         ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
         ctx_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // lm_head
 
         ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
         ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
@@ -220,15 +222,17 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
         model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
         model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
 
-        model.wte = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
-        model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
+        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
+        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
 
         // map by name
         model.tensors["model/ln_f/g"] = model.ln_f_g;
         model.tensors["model/ln_f/b"] = model.ln_f_b;
 
-        model.tensors["model/wte"] = model.wte;
-        model.tensors["model/wpe"] = model.wpe;
+        model.tensors["model/wte"]     = model.wte;
+        model.tensors["model/wpe"]     = model.wpe;
+        model.tensors["model/lm_head"] = model.lm_head;
 
         for (int i = 0; i < n_layer; ++i) {
             auto & layer = model.layers[i];
@@ -295,6 +299,8 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
     {
         size_t total_size = 0;
 
+        bool has_lm_head = false;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -362,6 +368,15 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 
             fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
 
+            // GPT-2 models share the WTE tensor as the LM head
+            if (name == "model/wte" && has_lm_head == false) {
+                memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
+            }
+
+            if (name == "model/lm_head") {
+                has_lm_head = true;
+            }
+
             total_size += ggml_nbytes(tensor);
         }
 
@@ -656,9 +671,9 @@ bool gpt2_eval(
     }
 
     // inpL = WTE * inpL
-    // [ 768, 50257] - model.wte
+    // [ 768, 50257] - model.lm_head
     // [ 768, N]     - inpL
-    inpL = ggml_mul_mat(ctx0, model.wte, inpL);
+    inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
 
     // logits -> probs
     //inpL = ggml_soft_max(ctx0, inpL);