]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
examples : add StarCoder/SantaCoder sample inference (#146)
authorNouamane Tazi <redacted>
Sat, 13 May 2023 09:54:03 +0000 (11:54 +0200)
committerGitHub <redacted>
Sat, 13 May 2023 09:54:03 +0000 (12:54 +0300)
* init commit

* fix building starcoder

* gen work

* fix vocab

* santacoder mha

* .

* fix quantize

* offload_state_dict

* endoftext

* rename scripts

* fix main

* scripts

* update README

* quickfixes

examples/CMakeLists.txt
examples/starcoder/CMakeLists.txt [new file with mode: 0644]
examples/starcoder/README.md [new file with mode: 0644]
examples/starcoder/convert-hf-to-ggml.py [new file with mode: 0644]
examples/starcoder/main.cpp [new file with mode: 0644]
examples/starcoder/quantize.cpp [new file with mode: 0644]

index 6d07eae4707d09e5ac7321728cd1b357553ca02d..8a8b9f71ad418a36174260511c7ec570d7a19468 100644 (file)
@@ -11,3 +11,4 @@ add_subdirectory(whisper)
 add_subdirectory(mnist)
 add_subdirectory(gpt-neox)
 add_subdirectory(dolly-v2)
+add_subdirectory(starcoder)
diff --git a/examples/starcoder/CMakeLists.txt b/examples/starcoder/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4c25b4d
--- /dev/null
@@ -0,0 +1,13 @@
+#
+# starcoder
+
+set(TEST_TARGET starcoder)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+#
+# starcoder-quantize
+
+set(TEST_TARGET starcoder-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
diff --git a/examples/starcoder/README.md b/examples/starcoder/README.md
new file mode 100644 (file)
index 0000000..1947578
--- /dev/null
@@ -0,0 +1,112 @@
+# ðŸ’« StarCoder
+
+This is a C++ example running ðŸ’« StarCoder inference using the [ggml](https://github.com/ggerganov/ggml) library.
+
+The program runs on the CPU - no video card is required.
+
+The example supports the following ðŸ’« StarCoder models:
+
+- `bigcode/starcoder`
+- `bigcode/gpt_bigcode-santacoder` aka the smol StarCoder
+
+Sample performance on MacBook M1 Pro:
+
+TODO
+
+
+Sample output:
+
+```
+$ ./bin/starcoder -h
+usage: ./bin/starcoder [options]
+
+options:
+  -h, --help            show this help message and exit
+  -s SEED, --seed SEED  RNG seed (default: -1)
+  -t N, --threads N     number of threads to use during computation (default: 8)
+  -p PROMPT, --prompt PROMPT
+                        prompt to start generation with (default: random)
+  -n N, --n_predict N   number of tokens to predict (default: 200)
+  --top_k N             top-k sampling (default: 40)
+  --top_p N             top-p sampling (default: 0.9)
+  --temp N              temperature (default: 1.0)
+  -b N, --batch_size N  batch size for prompt processing (default: 8)
+  -m FNAME, --model FNAME
+                        model path (default: models/starcoder-117M/ggml-model.bin)
+
+$ ./bin/starcoder -m ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin -p "def fibonnaci(" -t 4 --top_k 0 --top_p 0.95 --temp 0.2      
+main: seed = 1683881276
+gpt2_model_load: loading model from '../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin'
+gpt2_model_load: n_vocab = 49280
+gpt2_model_load: n_ctx   = 2048
+gpt2_model_load: n_embd  = 2048
+gpt2_model_load: n_head  = 16
+gpt2_model_load: n_layer = 24
+gpt2_model_load: ftype   = 3
+gpt2_model_load: ggml ctx size = 1794.90 MB
+gpt2_model_load: memory size =   768.00 MB, n_mem = 49152
+gpt2_model_load: model size  =  1026.83 MB
+main: prompt: 'def fibonnaci('
+main: number of tokens in prompt = 7, first 8 tokens: 563 24240 78 2658 64 2819 7 
+
+def fibonnaci(n):
+    if n == 0:
+        return 0
+    elif n == 1:
+        return 1
+    else:
+        return fibonacci(n-1) + fibonacci(n-2)
+
+print(fibo(10))
+
+main: mem per token =  9597928 bytes
+main:     load time =   480.43 ms
+main:   sample time =    26.21 ms
+main:  predict time =  3987.95 ms / 19.36 ms per token
+main:    total time =  4580.56 ms
+```
+
+## Quick start
+```bash
+git clone https://github.com/ggerganov/ggml
+cd ggml
+
+# Convert HF model to ggml
+python examples/starcoder/convert-hf-to-ggml.py bigcode/gpt_bigcode-santacoder
+
+# Build ggml + examples
+mkdir build && cd build
+cmake .. && make -j4 starcoder starcoder-quantize
+
+# quantize the model
+./bin/starcoder-quantize ../models/bigcode/gpt_bigcode-santacoder-ggml.bin ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin 3
+
+# run inference
+./bin/starcoder -m ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin -p "def fibonnaci(" --top_k 0 --top_p 0.95 --temp 0.2
+```
+
+
+## Downloading and converting the original models (💫 StarCoder)
+
+You can download the original model and convert it to `ggml` format using the script `convert-hf-to-ggml.py`:
+
+```
+# Convert HF model to ggml
+python examples/starcoder/convert-hf-to-ggml.py bigcode/gpt_bigcode-santacoder
+```
+
+This conversion requires that you have python and Transformers installed on your computer.
+
+## Quantizing the models
+
+You can also try to quantize the `ggml` models via 4-bit integer quantization.
+
+```
+# quantize the model
+./bin/starcoder-quantize ../models/bigcode/gpt_bigcode-santacoder-ggml.bin ../models/bigcode/gpt_bigcode-santacoder-ggml-q4_1.bin 3
+```
+
+| Model | Original size | Quantized size | Quantization type |
+| --- | --- | --- | --- |
+| `bigcode/gpt_bigcode-santacoder` | 5396.45 MB | 1026.83 MB | 4-bit integer (q4_1) |
+| `bigcode/starcoder` | 71628.23 MB | 13596.23 MB | 4-bit integer (q4_1) |
\ No newline at end of file
diff --git a/examples/starcoder/convert-hf-to-ggml.py b/examples/starcoder/convert-hf-to-ggml.py
new file mode 100644 (file)
index 0000000..c3ddfe7
--- /dev/null
@@ -0,0 +1,212 @@
+# Convert HF models to ggml format
+#
+
+import sys
+import struct
+import json
+import torch
+import numpy as np
+import re
+import os
+
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BloomForCausalLM
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+if len(sys.argv) < 2:
+    print("Usage: python convert-hf-to-ggml.py hf-model-name [use-f32]")
+    print("Example: python convert-hf-to-ggml.py bigcode/gpt_bigcode-santacoder")
+    print("Example: python convert-hf-to-ggml.py bigcode/starcoder")
+    sys.exit(1)
+
+model_name = sys.argv[1].strip()
+fname_out = "models/" + sys.argv[1].strip() + "-ggml.bin"
+os.makedirs(os.path.dirname(fname_out), exist_ok=True)
+
+
+
+# use 16-bit or 32-bit floats
+use_f16 = True
+if len(sys.argv) > 2:
+    use_f16 = False
+
+print("Loading model: ", model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+hparams = config.to_dict()
+model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16 if use_f16 else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True, offload_state_dict=True)
+print("Model loaded: ", model_name)
+
+#print (model)
+
+list_vars = model.state_dict()
+#print (list_vars)
+
+encoder = tokenizer.vocab
+# Add added_tokens (special tokens) to the encoder
+encoder.update(tokenizer.get_added_vocab())
+print(hparams)
+
+print("Saving ggml model to: ", fname_out)
+fout = open(fname_out, "wb")
+
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+vocab_size = hparams["vocab_size"]
+fout.write(struct.pack("i", vocab_size))
+# fout.write(struct.pack("i", len(encoder)))
+fout.write(struct.pack("i", hparams["n_positions"]))
+fout.write(struct.pack("i", hparams["n_embd"]))
+fout.write(struct.pack("i", hparams["n_head"]))
+fout.write(struct.pack("i", hparams["n_layer"]))
+fout.write(struct.pack("i", use_f16))
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+fout.write(struct.pack("i", vocab_size))
+
+counter = 0
+# sort by value
+for key in sorted(encoder, key=encoder.get):
+    text = bytearray([byte_decoder[c] for c in key])
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+    counter += 1
+
+# TODO: Repeat last token until vocab_size
+while counter < vocab_size:
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+    counter += 1
+# assert counter == config.vocab_size
+
+for name in list_vars.keys():
+    data = list_vars[name].squeeze().numpy()
+    print("Processing variable: " + name + " with shape: ", data.shape)
+
+    # rename headers to keep compatibility
+    if name == "transformer.ln_f.weight":
+        name = "model/ln_f/g"
+    elif name == "transformer.ln_f.bias":
+        name = "model/ln_f/b"
+    elif name == "transformer.wte.weight":
+        name = "model/wte"
+    elif name == "transformer.wpe.weight":
+        name = "model/wpe"
+    elif name == "lm_head.weight":
+        name = "model/lm_head"
+    elif re.match(r"transformer.h\.\d+\.ln_1\.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_1/g"
+    elif re.match(r"transformer.h\.\d+\.ln_1\.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_1/b"
+    elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_attn/w"
+    elif re.match(r"transformer.h\.\d+\.attn\.c_attn\.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_attn/b"
+    elif re.match(r"transformer.h\.\d+\.attn\.c_proj\.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_proj/w"
+    elif re.match(r"transformer.h.\d+.attn.c_proj.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/attn/c_proj/b"
+    elif re.match(r"transformer.h.\d+.ln_2.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_2/g"
+    elif re.match(r"transformer.h.\d+.ln_2.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/ln_2/b"
+    elif re.match(r"transformer.h.\d+.mlp.c_fc.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_fc/w"
+    elif re.match(r"transformer.h.\d+.mlp.c_fc.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_fc/b"
+    elif re.match(r"transformer.h.\d+.mlp.c_proj.weight", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_proj/w"
+    elif re.match(r"transformer.h.\d+.mlp.c_proj.bias", name):
+        i = re.findall("\d+", name)[0]
+        name = f"model/h{i}/mlp/c_proj/b"
+    else:
+        print("Unrecognized variable name. %s", name)
+
+    # we don't need these
+    if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
+        print("  Skipping variable: " + name)
+        continue
+
+    n_dims = len(data.shape);
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype = 0;
+    if use_f16:
+        if (name == "model/wte" or name == "model/lm_head" or name[-2:] == "/g" or name[-2:] == "/w") and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype = 0
+
+    "model/h.*/attn/c_attn/w"
+    "model/h.*/attn/c_proj/w"
+    "model/h.*/mlp/c_fc/w"
+    "model/h.*/mlp/c_proj/w"
+    if name[-14:] == "/attn/c_attn/w" or name[-14:] == "/attn/c_attn/b":
+        print("  Duplicate K,V heads to use MHA instead of MQA")
+
+        embed_dim = hparams["n_embd"]
+        head_dim = embed_dim // hparams["n_head"]
+
+        # ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
+        q, k ,v = np.split(data, (hparams["n_head"] * head_dim, (hparams["n_head"] + 1) * head_dim), axis=0)
+        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
+        if len(k.shape) == 2:
+            k = np.tile(k, (hparams["n_head"], 1))
+            v = np.tile(v, (hparams["n_head"], 1))
+        elif len(k.shape) == 1:
+            k = np.tile(k, (hparams["n_head"]))
+            v = np.tile(v, (hparams["n_head"]))
+        # concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
+        data = np.concatenate((q, k, v), axis=0)
+
+    # header
+    str = name.encode('utf-8')
+    fout.write(struct.pack("iii", n_dims, len(str), ftype))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(str);
+
+    # data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp
new file mode 100644 (file)
index 0000000..af9151c
--- /dev/null
@@ -0,0 +1,847 @@
+#include "ggml/ggml.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <iostream>
+#include <unistd.h>
+
+// default hparams (GPT-2 117M)
+// https://huggingface.co/bigcode/gpt_bigcode-santacoder/blob/main/config.json
+struct gpt2_hparams {
+    int32_t n_vocab = 49280;
+    int32_t n_ctx   = 2048;
+    int32_t n_embd  = 2048;
+    int32_t n_head  = 16;
+    int32_t n_layer = 24;
+    int32_t ftype   = 1;
+};
+
+struct gpt2_layer {
+    // normalization
+    struct ggml_tensor * ln_1_g;
+    struct ggml_tensor * ln_1_b;
+
+    struct ggml_tensor * ln_2_g;
+    struct ggml_tensor * ln_2_b;
+
+    // attention
+    struct ggml_tensor * c_attn_attn_w;
+    struct ggml_tensor * c_attn_attn_b;
+
+    struct ggml_tensor * c_attn_proj_w;
+    struct ggml_tensor * c_attn_proj_b;
+
+    // mlp
+    struct ggml_tensor * c_mlp_fc_w;
+    struct ggml_tensor * c_mlp_fc_b;
+
+    struct ggml_tensor * c_mlp_proj_w;
+    struct ggml_tensor * c_mlp_proj_b;
+};
+
+struct gpt2_model {
+    gpt2_hparams hparams;
+
+    // normalization
+    struct ggml_tensor * ln_f_g;
+    struct ggml_tensor * ln_f_b;
+
+    struct ggml_tensor * wte;     // position embedding
+    struct ggml_tensor * wpe;     //    token embedding
+    struct ggml_tensor * lm_head; // language model head
+
+    std::vector<gpt2_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    //
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
+    printf("%s: loading model from '%s'\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: ftype   = %d\n", __func__, hparams.ftype);
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        fin.read((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != model.hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *) &len, sizeof(len));
+
+            word.resize(len);
+            fin.read((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+
+            // if (i < 10) fprintf(stderr, "%.s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+    // in order to save memory and also to speed up the computation
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
+    if (wtype == GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
+                __func__, fname.c_str(), model.hparams.ftype);
+        return false;
+    }
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        const int head_dim = n_embd / hparams.n_head;
+        const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
+        const int kv_dim   = kv_heads * head_dim;
+
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
+
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // wte
+        ctx_size +=   n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
+        ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype);         // lm_head
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
+
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
+        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
+
+        ctx_size += n_layer*((n_embd + 2*kv_dim)*n_embd*ggml_type_sizef(wtype));         // c_attn_attn_w // TODO:
+        ctx_size += n_layer*(       (n_embd + 2*kv_dim)*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
+
+        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype));           // c_attn_proj_w
+        ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32));   // c_attn_proj_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_fc_w
+        ctx_size += n_layer*(       4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
+
+        ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype));         // c_mlp_proj_w
+        ctx_size += n_layer*(         n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
+
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+
+        ctx_size += (6 + 12*n_layer)*256; // object overhead
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size   = ctx_size,
+            .mem_buffer = NULL,
+            .no_alloc   = false,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+        const int n_vocab = hparams.n_vocab;
+
+        const int head_dim = n_embd / hparams.n_head;
+        const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
+        const int kv_dim   = kv_heads * head_dim;
+
+        model.layers.resize(n_layer);
+
+        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
+        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+
+        // map by name
+        model.tensors["model/ln_f/g"] = model.ln_f_g;
+        model.tensors["model/ln_f/b"] = model.ln_f_b;
+
+        model.tensors["model/wte"]     = model.wte;
+        model.tensors["model/wpe"]     = model.wpe;
+        model.tensors["model/lm_head"] = model.lm_head;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.ln_1_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_1_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.ln_2_g        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ln_2_b        = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd + 2*kv_dim);
+            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd + 2*kv_dim);
+
+            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,           n_embd, n_embd);
+            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            layer.c_mlp_fc_w    = ggml_new_tensor_2d(ctx, wtype,           n_embd, 4*n_embd); //TODO: 4*n_embd = config.n_inner
+            layer.c_mlp_fc_b    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+
+            layer.c_mlp_proj_w  = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
+            layer.c_mlp_proj_b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+
+            // map by name
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/g"]        = layer.ln_1_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_1/b"]        = layer.ln_1_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/g"]        = layer.ln_2_g;
+            model.tensors["model/h" + std::to_string(i) + "/ln_2/b"]        = layer.ln_2_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"]    = layer.c_mlp_fc_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"]    = layer.c_mlp_fc_b;
+
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"]  = layer.c_mlp_proj_w;
+            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"]  = layer.c_mlp_proj_b;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd  = hparams.n_embd;
+        const int n_layer = hparams.n_layer;
+        const int n_ctx   = hparams.n_ctx;
+
+        const int n_mem      = n_layer*n_ctx;
+        const int n_elements = n_embd*n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        size_t total_size = 0;
+
+        bool has_lm_head = false;
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = { 1, 1 };
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = model.tensors[name.data()];
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file. got %d, expected %d\n",
+                        __func__, name.data(), (int) ggml_nelements(tensor), nelements);
+                return false;
+            }
+
+            // for debugging
+            if (0) {
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            // GPT-2 models share the WTE tensor as the LM head
+            if (name == "model/wte" && has_lm_head == false) {
+                memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
+            }
+
+            if (name == "model/lm_head") {
+                has_lm_head = true;
+            }
+
+            total_size += ggml_nbytes(tensor);
+        }
+
+        printf("%s: model size  = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool gpt2_eval(
+        const gpt2_model & model,
+        const int n_threads,
+        const int n_past,
+        const std::vector<gpt_vocab::id> & embd_inp,
+              std::vector<float>         & embd_w,
+              size_t                     & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_head  = hparams.n_head;
+    const int n_vocab = hparams.n_vocab;
+
+    static size_t buf_size = 256u*1024*1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        .mem_size   = buf_size,
+        .mem_buffer = buf,
+        .no_alloc   = false,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = {};
+    gf.n_threads = n_threads;
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    for (int i = 0; i < N; ++i) {
+        ((int32_t *) position->data)[i] = n_past + i;
+    }
+
+    // wte + wpe
+    struct ggml_tensor * inpL =
+        ggml_add(ctx0,
+                ggml_get_rows(ctx0, model.wte, embd),
+                ggml_get_rows(ctx0, model.wpe, position));
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * cur;
+
+        // norm
+        {
+            // [ 768, N]
+            cur = ggml_norm(ctx0, inpL);
+
+            // cur = ln_1_g*cur + ln_1_b
+            // [ 768, N]
+            cur = ggml_add(ctx0,
+                    ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
+                        cur),
+                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
+        }
+
+        // attn
+        // [2304, 768] - model.layers[il].c_attn_attn_w
+        // [2304,   1] - model.layers[il].c_attn_attn_b
+        // [ 768,   N] - cur (in)
+        // [2304,   N] - cur (out)
+        //
+        // cur = attn_w*cur + attn_b
+        // [2304, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_attn_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
+                    cur);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
+            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
+            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
+
+            // store key and value to memory
+            if (N >= 1) {
+                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+            // [64, N, 12]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+            // [64, n_past + N, 12]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
+                        0, 2, 1, 3); //TODO: need to be tiled 
+
+            // GG: flash attention
+            //struct ggml_tensor * V =
+            //    ggml_cpy(ctx0,
+            //            ggml_permute(ctx0,
+            //                ggml_reshape_3d(ctx0,
+            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+            //                    n_embd/n_head, n_head, n_past + N),
+            //                1, 2, 0, 3),
+            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
+
+            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
+
+            // K * Q
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); //TODO: check if it broadcasts
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+                        );
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            // [n_past + N, N, 12]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+            // [n_past + N, 64, 12]
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
+                                n_embd/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
+
+            // KQV = transpose(V) * KQ_soft_max
+            // [64, N, 12]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // [64, 12, N]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // [768, N]
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+        }
+
+        // projection
+        // [ 768, 768] - model.layers[il].c_attn_proj_w
+        // [ 768,   1] - model.layers[il].c_attn_proj_b
+        // [ 768,   N] - cur (in)
+        // [ 768,   N] - cur (out)
+        //
+        // cur = proj_w*cur + proj_b
+        // [768, N]
+        {
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_attn_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
+                    cur);
+        }
+
+        // add the input
+        cur = ggml_add(ctx0, cur, inpL);
+
+        struct ggml_tensor * inpFF = cur;
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_norm(ctx0, inpFF);
+
+                // cur = ln_2_g*cur + ln_2_b
+                // [ 768, N]
+                cur = ggml_add(ctx0,
+                        ggml_mul(ctx0,
+                            ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
+                            cur),
+                        ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
+            }
+
+            // fully connected
+            // [3072, 768] - model.layers[il].c_mlp_fc_w
+            // [3072,   1] - model.layers[il].c_mlp_fc_b
+            // [ 768,   N] - cur (in)
+            // [3072,   N] - cur (out)
+            //
+            // cur = fc_w*cur + fc_b
+            // [3072, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_fc_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
+                    cur);
+
+            // GELU activation
+            // [3072, N]
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // [ 768, 3072] - model.layers[il].c_mlp_proj_w
+            // [ 768,    1] - model.layers[il].c_mlp_proj_b
+            // [3072,    N] - cur (in)
+            // [ 768,    N] - cur (out)
+            //
+            // cur = proj_w*cur + proj_b
+            // [768, N]
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].c_mlp_proj_w,
+                    cur);
+
+            cur = ggml_add(ctx0,
+                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
+                    cur);
+        }
+
+        // input for next layer
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
+
+    // norm
+    {
+        // [ 768, N]
+        inpL = ggml_norm(ctx0, inpL);
+
+        // inpL = ln_f_g*inpL + ln_f_b
+        // [ 768, N]
+        inpL = ggml_add(ctx0,
+                ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model.ln_f_g, inpL),
+                    inpL),
+                ggml_repeat(ctx0, model.ln_f_b, inpL));
+    }
+
+    // inpL = WTE * inpL
+    // [ 768, 50257] - model.lm_head
+    // [ 768, N]     - inpL
+    inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
+
+    // logits -> probs
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute       (ctx0, &gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print   (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result just for the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/gpt-2-117M/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        if( !isatty(STDIN_FILENO) ){
+            std::string line;
+            while( std::getline(std::cin, line) ){
+                params.prompt = params.prompt + "\n" + line;
+            }
+        } else {
+            params.prompt = gpt_random_prompt(rng);
+        }
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    gpt2_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gpt2_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us  = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+    printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size());
+    for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {
+        printf("%d ", embd_inp[i]);
+    }
+    printf("\n\n");
+
+    // submit the input prompt token-by-token
+    // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int   top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp  = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() >= params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 0) { //TODO: this is only for starcoder
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/examples/starcoder/quantize.cpp b/examples/starcoder/quantize.cpp
new file mode 100644 (file)
index 0000000..2ed612d
--- /dev/null
@@ -0,0 +1,178 @@
+#include "ggml/ggml.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <regex>
+
+// default hparams (GPT-2 117M)
+struct gpt2_hparams {
+    int32_t n_vocab = 49280;
+    int32_t n_ctx   = 2048;
+    int32_t n_embd  = 2048;
+    int32_t n_head  = 16;
+    int32_t n_layer = 24;
+    int32_t f16     = 1;
+};
+
+// quantize a model
+bool gpt2_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
+    gpt_vocab vocab;
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *) &magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *) &magic, sizeof(magic));
+    }
+
+    gpt2_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        finp.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        finp.read((char *) &hparams.f16,     sizeof(hparams.f16));
+
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
+        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
+        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
+        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: f16     = %d\n", __func__, hparams.f16);
+
+        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+        fout.write((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
+        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
+        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
+        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+        fout.write((char *) &ftype,           sizeof(hparams.f16));
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = 0;
+        finp.read ((char *) &n_vocab, sizeof(n_vocab));
+        fout.write((char *) &n_vocab, sizeof(n_vocab));
+
+        if (n_vocab != hparams.n_vocab) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
+                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
+            return false;
+        }
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read ((char *) &len, sizeof(len));
+            fout.write((char *) &len, sizeof(len));
+
+            word.resize(len);
+            finp.read ((char *) word.data(), len);
+            fout.write((char *) word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // regexes of tensor names to be quantized
+    const std::vector<std::string> to_quant = {
+        "model/wte",
+        "model/lm_head",
+        "model/h.*/attn/c_attn/w",
+        "model/h.*/attn/c_proj/w",
+        "model/h.*/mlp/c_fc/w",
+        "model/h.*/mlp/c_proj/w",
+    };
+
+    if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {
+        fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
+        return false;
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./gpt-2-quantize models/gpt-2-117M/ggml-model.bin models/gpt-2-117M/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
+        ggml_print_ftypes(stderr);
+        return 1;
+    }
+
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL, false };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!gpt2_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+    }
+
+    return 0;
+}