]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
examples : sample replit + MPT inference (#145)
authorLukas Möller <redacted>
Wed, 17 May 2023 19:58:21 +0000 (21:58 +0200)
committerGitHub <redacted>
Wed, 17 May 2023 19:58:21 +0000 (22:58 +0300)
* Add replit model

* Add unigram tokenization support

* Remove debug log

* Port alibi attn bias fix

* Remove torch input

* Fix hardcoded path

* Remove unsupported hyperparams

* Add mpt

* Add replit quantization script

* Remove debug print

* Add quantization support to mpt

* Reformat

* Remove trailing return type

* Implement stylistic changes

* use f16 in k/v memory calculations for replit/mpt

* Update context size calculation

* Add clip_qkv and alibi_bias_max support

* fix clamping implementation, remove implicit conversions

* Fix qkv if condition

* Fix replit context size calculation

* Potentially fix gcc compilation error

* Fix warning

* Adjust object overhead

* Remove dead code

examples/CMakeLists.txt
examples/mpt/CMakeLists.txt [new file with mode: 0644]
examples/mpt/convert-h5-to-ggml.py [new file with mode: 0644]
examples/mpt/main.cpp [new file with mode: 0644]
examples/mpt/quantize.cpp [new file with mode: 0644]
examples/replit/CMakeLists.txt [new file with mode: 0644]
examples/replit/convert-h5-to-ggml.py [new file with mode: 0644]
examples/replit/main.cpp [new file with mode: 0644]
examples/replit/quantize.cpp [new file with mode: 0644]
include/ggml/ggml.h
src/ggml.c

index ceca69e907d3591c9907d11f5265dd8c6382edc4..7a4bb246681d641e1a268a49b3266c8110639377 100644 (file)
@@ -23,4 +23,6 @@ add_subdirectory(whisper)
 add_subdirectory(mnist)
 add_subdirectory(gpt-neox)
 add_subdirectory(dolly-v2)
+add_subdirectory(replit)
+add_subdirectory(mpt)
 add_subdirectory(starcoder)
diff --git a/examples/mpt/CMakeLists.txt b/examples/mpt/CMakeLists.txt
new file mode 100644 (file)
index 0000000..09408f9
--- /dev/null
@@ -0,0 +1,13 @@
+#
+# mpt
+
+set(TEST_TARGET mpt)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+#
+# mpt-quantize
+
+set(TEST_TARGET mpt-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
diff --git a/examples/mpt/convert-h5-to-ggml.py b/examples/mpt/convert-h5-to-ggml.py
new file mode 100644 (file)
index 0000000..9bff9d3
--- /dev/null
@@ -0,0 +1,111 @@
+import sys
+import struct
+import json
+import numpy as np
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import sentencepiece.sentencepiece_model_pb2 as model
+
+if len(sys.argv) < 3:
+    print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
+    print("  ftype == 0 -> float32")
+    print("  ftype == 1 -> float16")
+    sys.exit(1)
+
+
+# output in the same directory as the model
+dir_model = sys.argv[1]
+fname_out = sys.argv[1] + "/ggml-model.bin"
+
+
+with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+    hparams = json.load(f)
+
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
+if len(sys.argv) > 2:
+    ftype = int(sys.argv[2])
+    if ftype < 0 or ftype > 1:
+        print("Invalid ftype: " + str(ftype))
+        sys.exit(1)
+    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
+
+
+tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    dir_model, low_cpu_mem_usage=True, trust_remote_code=True
+)
+# print (model)
+
+# print(tokenizer.encode('I believe the meaning of life is'))
+
+list_vars = model.state_dict()
+for name in list_vars.keys():
+    print(name, list_vars[name].shape, list_vars[name].dtype)
+
+fout = open(fname_out, "wb")
+
+print(hparams)
+
+fout.write(struct.pack("i", 0x67676D6C))  # magic: ggml in hex
+fout.write(struct.pack("i", hparams["d_model"]))
+fout.write(struct.pack("i", hparams["max_seq_len"]))
+fout.write(struct.pack("i", hparams["n_heads"]))
+fout.write(struct.pack("i", hparams["n_layers"]))
+fout.write(struct.pack("i", hparams["vocab_size"]))
+fout.write(struct.pack("f", hparams["attn_config"]["alibi_bias_max"]))
+fout.write(struct.pack("f", hparams["attn_config"]["clip_qkv"] or 0.0))
+fout.write(struct.pack("i", ftype))
+
+
+# TODO: temporary hack to not deal with implementing the tokenizer
+dot_token = tokenizer.encode(".")[0]
+for i in range(hparams["vocab_size"]):
+    text = tokenizer.decode([dot_token, i]).encode("utf-8")
+    # remove the first byte (it's always '.')
+    text = text[1:]
+    fout.write(struct.pack("i", len(text)))
+    fout.write(text)
+
+for name in list_vars.keys():
+    data = list_vars[name].squeeze().numpy()
+    print("Processing variable: " + name + " with shape: ", data.shape)
+
+    n_dims = len(data.shape)
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype_cur = 0
+    if ftype != 0:
+        if name[-7:] == ".weight" and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+
+    # header
+    str = name.encode("utf-8")
+    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(str)
+
+    # data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/examples/mpt/main.cpp b/examples/mpt/main.cpp
new file mode 100644 (file)
index 0000000..5a60367
--- /dev/null
@@ -0,0 +1,679 @@
+#include "ggml/ggml.h"
+
+#include "common-ggml.h"
+#include "common.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstddef>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <stdint.h>
+#include <string>
+#include <unistd.h>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+int n_ctx = 4096;
+
+// no defaults for now
+struct mpt_hparams {
+    int32_t d_model = 0;
+    int32_t max_seq_len = 0;
+    int32_t n_heads = 0;
+    int32_t n_layers = 0;
+    int32_t n_vocab = 0;
+    float alibi_bias_max = 0;
+    float clip_qkv = 0;
+    int32_t ftype = 0;
+};
+
+struct mpt_layer {
+    // pre normalization
+    struct ggml_tensor * norm_1_weight;
+
+    // attention
+    struct ggml_tensor * c_attn_wqkv_weight;
+    struct ggml_tensor * c_attn_out_proj_weight;
+
+    // post normalization
+    struct ggml_tensor * norm_2_weight;
+
+    // ff
+    struct ggml_tensor * ffn_up_proj;
+    struct ggml_tensor * ffn_down_proj;
+};
+
+struct mpt_model {
+    mpt_hparams hparams;
+
+    struct ggml_tensor * wte_weight;    // position embedding
+    struct ggml_tensor * norm_f_weight; // language model head
+
+    std::vector<mpt_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) {
+    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *)&magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *)&hparams.d_model, sizeof(hparams.d_model));
+        fin.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len));
+        fin.read((char *)&hparams.n_heads, sizeof(hparams.n_heads));
+        fin.read((char *)&hparams.n_layers, sizeof(hparams.n_layers));
+        fin.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *)&hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max));
+        fin.read((char *)&hparams.clip_qkv, sizeof(hparams.clip_qkv));
+        fin.read((char *)&hparams.ftype, sizeof(hparams.ftype));
+
+        printf("%s: d_model        = %d\n", __func__, hparams.d_model);
+        printf("%s: max_seq_len    = %d\n", __func__, hparams.max_seq_len);
+        printf("%s: n_heads        = %d\n", __func__, hparams.n_heads);
+        printf("%s: n_layers       = %d\n", __func__, hparams.n_layers);
+        printf("%s: n_vocab        = %d\n", __func__, hparams.n_vocab);
+        printf("%s: alibi_bias_max = %f\n", __func__, hparams.alibi_bias_max);
+        printf("%s: clip_qkv       = %f\n", __func__, hparams.clip_qkv);
+        printf("%s: ftype          = %d\n", __func__, hparams.ftype);
+    }
+
+    // load vocab
+    {
+        int32_t n_vocab = model.hparams.n_vocab;
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            fin.read((char *)&len, sizeof(len));
+
+            word.resize(len);
+            fin.read((char *)word.data(), len);
+
+            vocab.token_to_id[word] = i;
+            vocab.id_to_token[i] = word;
+        }
+    }
+
+    // for the big tensors, we have the option to store the data in 16-bit
+    // floats or quantized in order to save memory and also to speed up the
+    // computation
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype));
+    if (wtype == GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", __func__, fname.c_str(),
+                model.hparams.ftype);
+        return false;
+    }
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const size_t n_embd = hparams.d_model;
+        const size_t n_layer = hparams.n_layers;
+        const size_t n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd * n_vocab * ggml_type_sizef(wtype); // wte_weight
+        ctx_size += n_embd * ggml_type_sizef(GGML_TYPE_F32);   // norm_f_weight
+
+        ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32));      // ln_1_weight
+        ctx_size += n_layer * (3 * n_embd * n_embd * ggml_type_sizef(wtype)); // attn_Wqkv_weight
+        ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype));     // attn_out_proj_weight
+        ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32));      // ln_2_weight
+        ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight
+        ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight
+
+        ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k
+        ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v
+
+        ctx_size += (1 + 6 * n_layer) * 512; // object overhead
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size = ctx_size,
+            .mem_buffer = NULL,
+            .no_alloc = false,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const size_t n_embd = hparams.d_model;
+        const size_t n_layer = hparams.n_layers;
+        const size_t n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.wte_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
+        model.norm_f_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        // map by name
+        model.tensors["transformer.wte.weight"] = model.wte_weight;
+        model.tensors["transformer.norm_f.weight"] = model.norm_f_weight;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.norm_1_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+            layer.c_attn_wqkv_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, 3 * n_embd);
+            layer.c_attn_out_proj_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.norm_2_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+            layer.ffn_up_proj = ggml_new_tensor_2d(ctx, wtype, n_embd, 4 * n_embd);
+            layer.ffn_down_proj = ggml_new_tensor_2d(ctx, wtype, 4 * n_embd, n_embd);
+
+            // map by name
+            model.tensors["transformer.blocks." + std::to_string(i) + ".norm_1.weight"] = layer.norm_1_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".attn.Wqkv.weight"] = layer.c_attn_wqkv_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".attn.out_proj.weight"] =
+                layer.c_attn_out_proj_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".norm_2.weight"] = layer.norm_2_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".ffn.up_proj.weight"] = layer.ffn_up_proj;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".ffn.down_proj.weight"] = layer.ffn_down_proj;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const size_t n_embd = hparams.d_model;
+        const size_t n_layer = hparams.n_layers;
+
+        const int64_t n_mem = n_layer * n_ctx;
+        const int64_t n_elements = n_embd * n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        int n_tensors = 0;
+        size_t total_size = 0;
+
+        printf("%s: ", __func__);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = {1, 1};
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = model.tensors[name.data()];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr,
+                        "%s: tensor '%s' has wrong shape in model file: got [%5d, "
+                        "%5d], expected [%5d, %5d]\n",
+                        __func__, name.data(), (int)tensor->ne[0], (int)tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            // for debugging
+            if (0) {
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1],
+                       ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor) / 1024.0 / 1024.0, ggml_nbytes(tensor));
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr,
+                        "%s: tensor '%s' has wrong size in model file: got %zu, "
+                        "expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements * bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            total_size += ggml_nbytes(tensor);
+            if (++n_tensors % 8 == 0) {
+                printf(".");
+                fflush(stdout);
+            }
+        }
+
+        printf(" done\n");
+
+        printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size / 1024.0 / 1024.0, n_tensors);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
+              const std::vector<gpt_vocab::id> & embd_inp, std::vector<float> & embd_w, size_t & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd = hparams.d_model;
+    const int n_layer = hparams.n_layers;
+    const int n_head = hparams.n_heads;
+    const int n_vocab = hparams.n_vocab;
+
+    static size_t buf_size = 256u * 1024 * 1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token * N > buf_size) {
+        const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
+        // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
+        // buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        .mem_size = buf_size,
+        .mem_buffer = buf,
+        .no_alloc = false,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = {.n_threads = n_threads};
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
+
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte_weight, embd);
+
+    for (int il = 0; il < n_layer; ++il) {
+
+        struct ggml_tensor * cur;
+
+        // a = self.ln_1(x)
+        {
+            cur = ggml_norm(ctx0, inpL);
+
+            cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur);
+        }
+
+        // self-attention
+        //  b, _, past_key_value = self.attn(a, past_key_value=past_key_value,
+        //  attn_bias=attn_bias, attention_mask=attention_mask,
+        //  is_causal=is_causal)
+        {
+
+            // compute QKV
+            cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_wqkv_weight, cur);
+
+            if (model.hparams.clip_qkv > 0.0f) {
+                cur = ggml_clamp(ctx0, cur, -model.hparams.clip_qkv, model.hparams.clip_qkv);
+            }
+
+            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
+            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
+            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd);
+
+            // store key and value to memory
+            {
+                struct ggml_tensor * k =
+                    ggml_view_1d(ctx0, model.memory_k, N * n_embd,
+                                 (ggml_element_size(model.memory_k) * n_embd) * (il * n_ctx + n_past));
+                struct ggml_tensor * v =
+                    ggml_view_1d(ctx0, model.memory_v, N * n_embd,
+                                 (ggml_element_size(model.memory_v) * n_embd) * (il * n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0,
+            // 2, 1, 3) [64, N, 12]
+            struct ggml_tensor * Q = ggml_permute(
+                ctx0, ggml_cpy(ctx0, Qcur, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd / n_head, n_head, N)), 0, 2,
+                1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1,
+            // 3) [64, n_past + N, 12]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
+                                             ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_embd,
+                                                          il * n_ctx * ggml_element_size(model.memory_k) * n_embd),
+                                             n_embd / n_head, n_head, n_past + N),
+                             0, 2, 1, 3);
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));
+
+            struct ggml_tensor * KQ_scaled_alibi =
+                ggml_alibi(ctx0, KQ_scaled, n_past, n_head, model.hparams.alibi_bias_max);
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1,
+            // 2, 0, 3).contiguous() [n_past + N, 64, 12]
+            struct ggml_tensor * V_trans = ggml_cpy(
+                ctx0,
+                ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
+                                             ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_embd,
+                                                          il * n_ctx * ggml_element_size(model.memory_v) * n_embd),
+                                             n_embd / n_head, n_head, n_past + N),
+                             1, 2, 0, 3),
+                ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd / n_head, n_head));
+
+            // KQV = transpose(V) * KQ_soft_max
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection
+            { cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_out_proj_weight, cur); }
+        }
+
+        inpL = ggml_add(ctx0, inpL, cur);
+
+        // m = self.ln_2(x)
+        {
+            cur = ggml_norm(ctx0, inpL);
+
+            cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur);
+        }
+
+        // n = self.mlp(m)
+        {
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_proj, cur);
+
+            // GELU activation
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // cur = proj_w*cur + proj_b
+            cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_proj, cur);
+        }
+
+        // x = x + n
+        inpL = ggml_add(ctx0, inpL, cur);
+    }
+
+    // norm
+    {
+        inpL = ggml_norm(ctx0, inpL);
+        // inpL = ln_f_g*inpL
+        inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
+    }
+
+    // output embedding weight tied to input embedding
+    inpL = ggml_mul_mat(ctx0, model.wte_weight, inpL);
+
+    // logits -> probs
+    // inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute(ctx0, &gf);
+
+    // std::cout << "Qcur" << std::endl;
+    // print_tensor(Qcur);
+
+    // if (n_past%100 == 0) {
+    // ggml_graph_print(&gf);
+    // ggml_graph_dump_dot(&gf, NULL, "mpt-model.dot");
+    // }
+
+    // return result for just the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0) / N;
+    }
+    // printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        if (!isatty(STDIN_FILENO)) {
+            std::string line;
+            while (std::getline(std::cin, line)) {
+                params.prompt = params.prompt + "\n" + line;
+            }
+        } else {
+            params.prompt = gpt_random_prompt(rng);
+        }
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    mpt_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!mpt_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<int> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+
+    for (int i = 0; i < embd_inp.size(); i++) {
+        printf("%s: token[%d] = %6d\n", __func__, i, embd_inp[i]);
+        // vocab.id_to_token.at(embd_inp[i]).c_str()
+    }
+    printf("\n");
+
+    params.n_predict = std::min(params.n_predict, n_ctx - (int)embd_inp.size());
+
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!mpt_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() > params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 0) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f,
+               t_predict_us / 1000.0f / n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/examples/mpt/quantize.cpp b/examples/mpt/quantize.cpp
new file mode 100644 (file)
index 0000000..8f32bdd
--- /dev/null
@@ -0,0 +1,180 @@
+#include "ggml/ggml.h"
+
+#include "common-ggml.h"
+#include "common.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <regex>
+#include <string>
+#include <vector>
+
+struct mpt_hparams {
+    int32_t d_model = 0;
+    int32_t max_seq_len = 0;
+    int32_t n_heads = 0;
+    int32_t n_layers = 0;
+    int32_t n_vocab = 0;
+    float alibi_bias_max = 0;
+    float clip_qkv = 0;
+    int32_t ftype = 0;
+};
+
+// quantize a model
+bool mpt_model_quantize(const std::string & fname_inp,
+                        const std::string & fname_out, ggml_ftype ftype) {
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__,
+                fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__,
+                fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *)&magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n",
+                    __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *)&magic, sizeof(magic));
+    }
+
+    mpt_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *)&hparams.d_model, sizeof(hparams.d_model));
+        finp.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len));
+        finp.read((char *)&hparams.n_heads, sizeof(hparams.n_heads));
+        finp.read((char *)&hparams.n_layers, sizeof(hparams.n_layers));
+        finp.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab));
+        finp.read((char *)&hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max));
+        finp.read((char *)&hparams.clip_qkv, sizeof(hparams.clip_qkv));
+        finp.read((char *)&hparams.ftype, sizeof(hparams.ftype));
+
+        printf("%s: d_model = %d\n", __func__, hparams.d_model);
+        printf("%s: max_seq_len = %d\n", __func__, hparams.max_seq_len);
+        printf("%s: n_heads = %d\n", __func__, hparams.n_heads);
+        printf("%s: n_layers = %d\n", __func__, hparams.n_layers);
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: alibi_bias_max = %f\n", __func__, hparams.alibi_bias_max);
+        printf("%s: clip_qkv = %f\n", __func__, hparams.clip_qkv);
+        printf("%s: ftype = %d\n", __func__, hparams.ftype);
+
+        fout.write((char *)&hparams.d_model, sizeof(hparams.d_model));
+        fout.write((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len));
+        fout.write((char *)&hparams.n_heads, sizeof(hparams.n_heads));
+        fout.write((char *)&hparams.n_layers, sizeof(hparams.n_layers));
+        fout.write((char *)&hparams.n_vocab, sizeof(hparams.n_vocab));
+        fout.write((char *)&hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max));
+        fout.write((char *)&hparams.clip_qkv, sizeof(hparams.clip_qkv));
+        fout.write((char *)&ftype, sizeof(hparams.ftype));
+    }
+
+    // load vocab
+    {
+        const int32_t n_vocab = hparams.n_vocab;
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read((char *)&len, sizeof(len));
+            fout.write((char *)&len, sizeof(len));
+
+            word.resize(len);
+            finp.read((char *)word.data(), len);
+            fout.write((char *)word.data(), len);
+        }
+    }
+
+    printf("%s: quantizing tensors\n", __func__);
+
+    // regexes of tensor names to be quantized
+    const std::vector<std::string> to_quant = {
+        ".*weight",
+    };
+
+    if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {
+        fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__,
+                fname_inp.c_str());
+        return false;
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./mpt-quantize models/mpt/ggml-model.bin
+//  models/mpt/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n",
+                argv[0]);
+        ggml_print_ftypes(stderr);
+        return 1;
+    }
+
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = {0, NULL, false};
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!mpt_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n",
+                    __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__,
+               t_quantize_us / 1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__,
+               (t_main_end_us - t_main_start_us) / 1000.0f);
+    }
+
+    return 0;
+}
diff --git a/examples/replit/CMakeLists.txt b/examples/replit/CMakeLists.txt
new file mode 100644 (file)
index 0000000..696b7f9
--- /dev/null
@@ -0,0 +1,13 @@
+#
+# replit
+
+set(TEST_TARGET replit)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+#
+# replit-quantize
+
+set(TEST_TARGET replit-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
diff --git a/examples/replit/convert-h5-to-ggml.py b/examples/replit/convert-h5-to-ggml.py
new file mode 100644 (file)
index 0000000..310074b
--- /dev/null
@@ -0,0 +1,113 @@
+from pathlib import Path
+import sys
+import struct
+import json
+import numpy as np
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import sentencepiece.sentencepiece_model_pb2 as model
+
+if len(sys.argv) < 3:
+    print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
+    print("  ftype == 0 -> float32")
+    print("  ftype == 1 -> float16")
+    sys.exit(1)
+
+
+# output in the same directory as the model
+dir_model = sys.argv[1]
+fname_out = sys.argv[1] + "/ggml-model.bin"
+
+
+with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+    hparams = json.load(f)
+
+sp_proto = model.ModelProto()
+sp_proto.ParseFromString(open(Path(sys.argv[1]) / "spiece.model", "rb").read())
+
+
+# possible data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
+if len(sys.argv) > 2:
+    ftype = int(sys.argv[2])
+    if ftype < 0 or ftype > 1:
+        print("Invalid ftype: " + str(ftype))
+        sys.exit(1)
+    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
+
+
+tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    dir_model, low_cpu_mem_usage=True, trust_remote_code=True
+)
+# print (model)
+
+# print(tokenizer.encode('I believe the meaning of life is'))
+
+list_vars = model.state_dict()
+for name in list_vars.keys():
+    print(name, list_vars[name].shape, list_vars[name].dtype)
+
+fout = open(fname_out, "wb")
+
+print(hparams)
+
+fout.write(struct.pack("i", 0x67676D6C))  # magic: ggml in hex
+fout.write(struct.pack("i", hparams["d_model"]))
+fout.write(struct.pack("i", hparams["max_seq_len"]))
+fout.write(struct.pack("i", hparams["n_heads"]))
+fout.write(struct.pack("i", hparams["n_layers"]))
+fout.write(struct.pack("i", hparams["vocab_size"]))
+fout.write(struct.pack("i", ftype))
+
+
+# TODO: temporary hack to not deal with implementing the tokenizer
+for piece in sp_proto.pieces:
+    encoded_piece = piece.piece.encode("utf-8")
+    fout.write(struct.pack("i", len(encoded_piece)))
+    fout.write(encoded_piece)
+    fout.write(struct.pack("f", piece.score))
+
+
+for name in list_vars.keys():
+    data = list_vars[name].squeeze().numpy()
+    print("Processing variable: " + name + " with shape: ", data.shape)
+
+    n_dims = len(data.shape)
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype_cur = 0
+    if ftype != 0:
+        if name[-7:] == ".weight" and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+
+    # header
+    str = name.encode("utf-8")
+    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+    for i in range(n_dims):
+        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(str)
+
+    # data
+    data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
diff --git a/examples/replit/main.cpp b/examples/replit/main.cpp
new file mode 100644 (file)
index 0000000..e01db75
--- /dev/null
@@ -0,0 +1,767 @@
+#include "ggml/ggml.h"
+
+#include "common-ggml.h"
+#include "common.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstddef>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <stdint.h>
+#include <string>
+#include <unistd.h>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+using piece_t = std::pair<std::size_t, float>;
+using piece_map_t = std::unordered_map<std::string, piece_t>;
+
+struct replit_tokenizer {
+    gpt_vocab raw_vocab;
+    piece_map_t piece_map;
+    std::vector<std::string> vocab;
+};
+
+std::pair<std::vector<std::size_t>, float> encode_word(const std::string & word, const piece_map_t & model) {
+    std::vector<int> best_segmentations_starts(word.length() + 1, -1);
+    best_segmentations_starts[0] = 0;
+
+    std::vector<float> best_segmentations_scores(word.length() + 1, -std::numeric_limits<float>::infinity());
+    best_segmentations_scores[0] = 1.0;
+
+    for (int start_idx = 0; start_idx < word.length(); ++start_idx) {
+        float best_score_at_start = best_segmentations_scores[start_idx];
+        for (int end_idx = start_idx + 1; end_idx <= word.length(); ++end_idx) {
+            std::string token = word.substr(start_idx, end_idx - start_idx);
+            if (model.count(token) && best_score_at_start != -std::numeric_limits<float>::infinity()) {
+                float token_score = model.at(token).second;
+                float score = token_score + best_score_at_start;
+                if (best_segmentations_scores[end_idx] == -std::numeric_limits<float>::infinity() ||
+                    best_segmentations_scores[end_idx] > score) {
+                    best_segmentations_starts[end_idx] = start_idx;
+                    best_segmentations_scores[end_idx] = score;
+                }
+            }
+        }
+    }
+
+    if (best_segmentations_scores.back() == -std::numeric_limits<float>::infinity()) {
+        return std::make_pair(std::vector<std::size_t>{0}, 0.0f);
+    }
+
+    float score = best_segmentations_scores.back();
+    int start = best_segmentations_starts.back();
+    int end = word.length();
+    std::vector<std::size_t> tokens;
+    while (start != 0) {
+        const auto token_id = model.at(word.substr(start, end - start)).first;
+        tokens.insert(tokens.begin(), token_id);
+        int next_start = best_segmentations_starts[start];
+        end = start;
+        start = next_start;
+    }
+    const auto token_id = model.at(word.substr(start, end - start)).first;
+    tokens.insert(tokens.begin(), token_id);
+    return std::make_pair(tokens, score);
+}
+
+bool replit_tokenizer_load(replit_tokenizer & tokenizer, std::istream & fin, int max_vocab_size) {
+
+    for (std::size_t i = 0; i < max_vocab_size; i++) {
+
+        uint32_t len;
+        fin.read((char *)&len, sizeof(len));
+
+        std::string word;
+        word.resize(len);
+        fin.read((char *)word.data(), len);
+
+        float score;
+        fin.read((char *)&score, sizeof(score));
+
+        tokenizer.piece_map[word] = std::make_pair(i, -score);
+        tokenizer.raw_vocab.id_to_token[i] = word;
+    }
+
+    return true;
+}
+
+std::string replace_all(const std::string & str,    // where to work
+                        const std::string & find,   // substitute 'find'
+                        const std::string & replace //      by 'replace'
+) {
+    using namespace std;
+    string result;
+    size_t find_len = find.size();
+    size_t pos, from = 0;
+    while (string::npos != (pos = str.find(find, from))) {
+        result.append(str, from, pos - from);
+        result.append(replace);
+        from = pos + find_len;
+    }
+    result.append(str, from, string::npos);
+    return result;
+}
+
+std::string ws_symbol = "\342\226\201";
+std::vector<std::size_t> replit_tokenizer_tokenize(replit_tokenizer & tokenizer, const std::string & text) {
+    std::vector<std::size_t> tokens;
+    auto normalized_text = replace_all(text, " ", ws_symbol);
+    auto tokenized = encode_word(normalized_text, tokenizer.piece_map);
+
+    return tokenized.first;
+}
+
+std::string replit_tokenizer_detokenize(replit_tokenizer & tokenizer, const std::vector<std::size_t> & tokens) {
+    std::string text;
+    for (auto token : tokens) {
+        text += tokenizer.raw_vocab.id_to_token[token];
+    }
+    auto denormalized_text = replace_all(text, ws_symbol, " ");
+    return denormalized_text;
+}
+
+// no defaults for now
+struct mpt_hparams {
+    int32_t d_model = 0;
+    int32_t max_seq_len = 0;
+    int32_t n_heads = 0;
+    int32_t n_layers = 0;
+    int32_t n_vocab = 0;
+    int32_t ftype = 0;
+};
+
+struct replit_layer {
+    // pre normalization
+    struct ggml_tensor * ln_1_weight;
+
+    // attention
+    struct ggml_tensor * c_attn_wqkv_weight;
+
+    struct ggml_tensor * c_attn_out_proj_weight;
+
+    // post normalization
+    struct ggml_tensor * ln_2_weight;
+
+    // ff
+    struct ggml_tensor * c_mlp_mlp_up_weight;
+
+    struct ggml_tensor * c_mlp_mlp_down_weight;
+};
+
+struct replit_model {
+    mpt_hparams hparams;
+
+    struct ggml_tensor * wte_weight;  // position embedding
+    struct ggml_tensor * ln_f_weight; // language model head
+
+    std::vector<replit_layer> layers;
+
+    // key + value memory
+    struct ggml_tensor * memory_k;
+    struct ggml_tensor * memory_v;
+
+    struct ggml_context * ctx;
+    std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool replit_model_load(const std::string & fname, replit_model & model, replit_tokenizer & vocab) {
+    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+
+    auto fin = std::ifstream(fname, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        fin.read((char *)&magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+            return false;
+        }
+    }
+
+    // load hparams
+    {
+        auto & hparams = model.hparams;
+
+        fin.read((char *)&hparams.d_model, sizeof(hparams.d_model));
+        fin.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len));
+        fin.read((char *)&hparams.n_heads, sizeof(hparams.n_heads));
+        fin.read((char *)&hparams.n_layers, sizeof(hparams.n_layers));
+        fin.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab));
+        fin.read((char *)&hparams.ftype, sizeof(hparams.ftype));
+
+        printf("%s: d_model       = %d\n", __func__, hparams.d_model);
+        printf("%s: max_seq_len   = %d\n", __func__, hparams.max_seq_len);
+        printf("%s: n_heads       = %d\n", __func__, hparams.n_heads);
+        printf("%s: n_layers      = %d\n", __func__, hparams.n_layers);
+        printf("%s: n_vocab      = %d\n", __func__, hparams.n_vocab);
+        printf("%s: ftype   = %d\n", __func__, hparams.ftype);
+    }
+
+    // load vocab
+    replit_tokenizer_load(vocab, fin, model.hparams.n_vocab);
+
+    // for the big tensors, we have the option to store the data in 16-bit
+    // floats or quantized in order to save memory and also to speed up the
+    // computation
+    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype));
+    if (wtype == GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", __func__, fname.c_str(),
+                model.hparams.ftype);
+        return false;
+    }
+
+    auto & ctx = model.ctx;
+
+    size_t ctx_size = 0;
+
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd = hparams.d_model;
+        const int n_layer = hparams.n_layers;
+        const int n_ctx = hparams.max_seq_len;
+        const int n_vocab = hparams.n_vocab;
+
+        ctx_size += n_embd * n_vocab * ggml_type_sizef(wtype); // wte_weight
+        ctx_size += n_embd * ggml_type_sizef(GGML_TYPE_F32);   // ln_f_weight
+
+        ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32));      // ln_1_weight
+        ctx_size += n_layer * (3 * n_embd * n_embd * ggml_type_sizef(wtype)); // attn_Wqkv_weight
+        ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype));     // attn_out_proj_weight
+        ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32));      // ln_2_weight
+        ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight
+        ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight
+
+        ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k
+        ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v
+
+        ctx_size += (1 + 6 * n_layer) * 512; // object overhead
+
+        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
+    }
+
+    // create the ggml context
+    {
+        struct ggml_init_params params = {
+            .mem_size = ctx_size,
+            .mem_buffer = NULL,
+            .no_alloc = false,
+        };
+
+        model.ctx = ggml_init(params);
+        if (!model.ctx) {
+            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+            return false;
+        }
+    }
+
+    // prepare memory for the weights
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd = hparams.d_model;
+        const int n_layer = hparams.n_layers;
+        const int n_vocab = hparams.n_vocab;
+
+        model.layers.resize(n_layer);
+
+        model.wte_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
+        model.ln_f_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+        // map by name
+        model.tensors["transformer.wte.weight"] = model.wte_weight;
+        model.tensors["transformer.ln_f.weight"] = model.ln_f_weight;
+
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = model.layers[i];
+
+            layer.ln_1_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+            layer.c_attn_wqkv_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, 3 * n_embd);
+            layer.c_attn_out_proj_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.ln_2_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+            layer.c_mlp_mlp_up_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, 4 * n_embd);
+            layer.c_mlp_mlp_down_weight = ggml_new_tensor_2d(ctx, wtype, 4 * n_embd, n_embd);
+
+            // map by name
+            model.tensors["transformer.blocks." + std::to_string(i) + ".ln_1.weight"] = layer.ln_1_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".attn.Wqkv.weight"] = layer.c_attn_wqkv_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".attn.out_proj.weight"] =
+                layer.c_attn_out_proj_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".ln_2.weight"] = layer.ln_2_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".mlp.mlp_up.weight"] = layer.c_mlp_mlp_up_weight;
+            model.tensors["transformer.blocks." + std::to_string(i) + ".mlp.mlp_down.weight"] =
+                layer.c_mlp_mlp_down_weight;
+        }
+    }
+
+    // key + value memory
+    {
+        const auto & hparams = model.hparams;
+
+        const int n_embd = hparams.d_model;
+        const int n_layer = hparams.n_layers;
+        const int n_ctx = hparams.max_seq_len;
+
+        const int64_t n_mem = n_layer * n_ctx;
+        const int64_t n_elements = n_embd * n_mem;
+
+        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+
+        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+        printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem);
+    }
+
+    // load weights
+    {
+        int n_tensors = 0;
+        size_t total_size = 0;
+
+        printf("%s: ", __func__);
+
+        while (true) {
+            int32_t n_dims;
+            int32_t length;
+            int32_t ttype;
+
+            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+            fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
+
+            if (fin.eof()) {
+                break;
+            }
+
+            int32_t nelements = 1;
+            int32_t ne[2] = {1, 1};
+            for (int i = 0; i < n_dims; ++i) {
+                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                nelements *= ne[i];
+            }
+
+            std::string name(length, 0);
+            fin.read(&name[0], length);
+
+            if (model.tensors.find(name.data()) == model.tensors.end()) {
+                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                return false;
+            }
+
+            auto tensor = model.tensors[name.data()];
+            if (ggml_nelements(tensor) != nelements) {
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                return false;
+            }
+
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                fprintf(stderr,
+                        "%s: tensor '%s' has wrong shape in model file: got [%5d, "
+                        "%5d], expected [%5d, %5d]\n",
+                        __func__, name.data(), (int)tensor->ne[0], (int)tensor->ne[1], ne[0], ne[1]);
+                return false;
+            }
+
+            // for debugging
+            if (0) {
+                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1],
+                       ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor) / 1024.0 / 1024.0, ggml_nbytes(tensor));
+            }
+
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+            if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                fprintf(stderr,
+                        "%s: tensor '%s' has wrong size in model file: got %zu, "
+                        "expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements * bpe);
+                return false;
+            }
+
+            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+            total_size += ggml_nbytes(tensor);
+            if (++n_tensors % 8 == 0) {
+                printf(".");
+                fflush(stdout);
+            }
+        }
+
+        printf(" done\n");
+
+        printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size / 1024.0 / 1024.0, n_tensors);
+    }
+
+    fin.close();
+
+    return true;
+}
+
+// evaluate the transformer
+//
+//   - model:     the model
+//   - n_threads: number of threads to use
+//   - n_past:    the context size so far
+//   - embd_inp:  the embeddings of the tokens in the context
+//   - embd_w:    the predicted logits for the next token
+//
+bool replit_eval(const replit_model & model, const int n_threads, const int n_past,
+                 const std::vector<gpt_vocab::id> & embd_inp, std::vector<float> & embd_w, size_t & mem_per_token) {
+    const int N = embd_inp.size();
+
+    const auto & hparams = model.hparams;
+
+    const int n_embd = hparams.d_model;
+    const int n_layer = hparams.n_layers;
+    const int n_ctx = hparams.max_seq_len;
+    const int n_head = hparams.n_heads;
+    const int n_vocab = hparams.n_vocab;
+
+    static size_t buf_size = 256u * 1024 * 1024;
+    static void * buf = malloc(buf_size);
+
+    if (mem_per_token > 0 && mem_per_token * N > buf_size) {
+        const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
+        // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
+        // buf_size, buf_size_new);
+
+        // reallocate
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+            return false;
+        }
+    }
+
+    struct ggml_init_params params = {
+        .mem_size = buf_size,
+        .mem_buffer = buf,
+        .no_alloc = false,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph gf = {.n_threads = n_threads};
+
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
+
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte_weight, embd);
+
+    for (int il = 0; il < n_layer; ++il) {
+
+        struct ggml_tensor * cur;
+
+        // a = self.ln_1(x)
+        {
+            cur = ggml_norm(ctx0, inpL);
+
+            cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_weight, cur), cur);
+        }
+
+        // self-attention
+        //  b, _, past_key_value = self.attn(a, past_key_value=past_key_value,
+        //  attn_bias=attn_bias, attention_mask=attention_mask,
+        //  is_causal=is_causal)
+        {
+
+            // compute QKV
+            { cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_wqkv_weight, cur); }
+
+            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd);
+            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd);
+            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd);
+
+            // store key and value to memory
+            {
+                struct ggml_tensor * k =
+                    ggml_view_1d(ctx0, model.memory_k, N * n_embd,
+                                 (ggml_element_size(model.memory_k) * n_embd) * (il * n_ctx + n_past));
+                struct ggml_tensor * v =
+                    ggml_view_1d(ctx0, model.memory_v, N * n_embd,
+                                 (ggml_element_size(model.memory_v) * n_embd) * (il * n_ctx + n_past));
+
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0,
+            // 2, 1, 3) [64, N, 12]
+            struct ggml_tensor * Q = ggml_permute(
+                ctx0, ggml_cpy(ctx0, Qcur, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd / n_head, n_head, N)), 0, 2,
+                1, 3);
+
+            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1,
+            // 3) [64, n_past + N, 12]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
+                                             ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_embd,
+                                                          il * n_ctx * ggml_element_size(model.memory_k) * n_embd),
+                                             n_embd / n_head, n_head, n_past + N),
+                             0, 2, 1, 3);
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));
+
+            struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8.0);
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1,
+            // 2, 0, 3).contiguous() [n_past + N, 64, 12]
+            struct ggml_tensor * V_trans = ggml_cpy(
+                ctx0,
+                ggml_permute(ctx0,
+                             ggml_reshape_3d(ctx0,
+                                             ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_embd,
+                                                          il * n_ctx * ggml_element_size(model.memory_v) * n_embd),
+                                             n_embd / n_head, n_head, n_past + N),
+                             1, 2, 0, 3),
+                ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd / n_head, n_head));
+
+            // KQV = transpose(V) * KQ_soft_max
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection
+            { cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_out_proj_weight, cur); }
+        }
+
+        inpL = ggml_add(ctx0, inpL, cur);
+
+        // m = self.ln_2(x)
+        {
+            cur = ggml_norm(ctx0, inpL);
+
+            cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_weight, cur), cur);
+        }
+
+        // n = self.mlp(m)
+        {
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].c_mlp_mlp_up_weight, cur);
+
+            // GELU activation
+            cur = ggml_gelu(ctx0, cur);
+
+            // projection
+            // cur = proj_w*cur + proj_b
+            cur = ggml_mul_mat(ctx0, model.layers[il].c_mlp_mlp_down_weight, cur);
+        }
+
+        // x = x + n
+        inpL = ggml_add(ctx0, inpL, cur);
+    }
+
+    // norm
+    {
+        inpL = ggml_norm(ctx0, inpL);
+        // inpL = ln_f_g*inpL
+        inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.ln_f_weight, inpL), inpL);
+    }
+
+    // output embedding weight tied to input embedding
+    inpL = ggml_mul_mat(ctx0, model.wte_weight, inpL);
+
+    // logits -> probs
+    // inpL = ggml_soft_max(ctx0, inpL);
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute(ctx0, &gf);
+
+    // std::cout << "Qcur" << std::endl;
+    // print_tensor(Qcur);
+
+    // if (n_past%100 == 0) {
+    // ggml_graph_print(&gf);
+    // ggml_graph_dump_dot(&gf, NULL, "replit-model.dot");
+    // }
+
+    // return result for just the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab);
+
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0) / N;
+    }
+    // printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        params.seed = time(NULL);
+    }
+
+    printf("%s: seed = %d\n", __func__, params.seed);
+
+    std::mt19937 rng(params.seed);
+    if (params.prompt.empty()) {
+        if (!isatty(STDIN_FILENO)) {
+            std::string line;
+            while (std::getline(std::cin, line)) {
+                params.prompt = params.prompt + "\n" + line;
+            }
+        } else {
+            params.prompt = gpt_random_prompt(rng);
+        }
+    }
+
+    int64_t t_load_us = 0;
+
+    replit_tokenizer vocab;
+    replit_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!replit_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int n_past = 0;
+
+    int64_t t_sample_us = 0;
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<std::size_t> embd_inp = replit_tokenizer_tokenize(vocab, params.prompt);
+
+    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+
+    for (int i = 0; i < embd_inp.size(); i++) {
+        printf("%s: token[%d] = %6lu\n", __func__, i, embd_inp[i]);
+        // vocab.id_to_token.at(embd_inp[i]).c_str()
+    }
+    printf("\n");
+
+    params.n_predict = std::min(params.n_predict, model.hparams.max_seq_len - (int)embd_inp.size());
+
+    std::vector<gpt_vocab::id> embd;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    replit_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token);
+
+    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+        // predict
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!replit_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+        }
+
+        n_past += embd.size();
+        embd.clear();
+
+        if (i >= embd_inp.size()) {
+            // sample next token
+            const int top_k = params.top_k;
+            const float top_p = params.top_p;
+            const float temp = params.temp;
+
+            const int n_vocab = model.hparams.n_vocab;
+
+            gpt_vocab::id id = 0;
+
+            {
+                const int64_t t_start_sample_us = ggml_time_us();
+
+                id = gpt_sample_top_k_top_p(vocab.raw_vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p,
+                                            temp, rng);
+
+                t_sample_us += ggml_time_us() - t_start_sample_us;
+            }
+
+            // add it to the context
+            embd.push_back(id);
+        } else {
+            // if here, it means we are still processing the input prompt
+            for (int k = i; k < embd_inp.size(); k++) {
+                embd.push_back(embd_inp[k]);
+                if (embd.size() > params.n_batch) {
+                    break;
+                }
+            }
+            i += embd.size() - 1;
+        }
+
+        // display text
+        for (auto id : embd) {
+            printf("%s", replit_tokenizer_detokenize(vocab, {static_cast<std::size_t>(id)}).c_str());
+        }
+        fflush(stdout);
+
+        // end of text token
+        if (embd.back() == 0) {
+            break;
+        }
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
+        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f);
+        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f,
+               t_predict_us / 1000.0f / n_past);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
diff --git a/examples/replit/quantize.cpp b/examples/replit/quantize.cpp
new file mode 100644 (file)
index 0000000..40a5806
--- /dev/null
@@ -0,0 +1,176 @@
+#include "ggml/ggml.h"
+
+#include "common-ggml.h"
+#include "common.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <regex>
+#include <string>
+#include <vector>
+
+struct mpt_hparams {
+    int32_t d_model = 0;
+    int32_t max_seq_len = 0;
+    int32_t n_heads = 0;
+    int32_t n_layers = 0;
+    int32_t n_vocab = 0;
+    int32_t ftype = 0;
+};
+
+// quantize a model
+bool mpt_model_quantize(const std::string & fname_inp,
+                        const std::string & fname_out, ggml_ftype ftype) {
+
+    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+    auto finp = std::ifstream(fname_inp, std::ios::binary);
+    if (!finp) {
+        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__,
+                fname_inp.c_str());
+        return false;
+    }
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+    if (!fout) {
+        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__,
+                fname_out.c_str());
+        return false;
+    }
+
+    // verify magic
+    {
+        uint32_t magic;
+        finp.read((char *)&magic, sizeof(magic));
+        if (magic != 0x67676d6c) {
+            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n",
+                    __func__, fname_inp.c_str());
+            return false;
+        }
+
+        fout.write((char *)&magic, sizeof(magic));
+    }
+
+    mpt_hparams hparams;
+
+    // load hparams
+    {
+        finp.read((char *)&hparams.d_model, sizeof(hparams.d_model));
+        finp.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len));
+        finp.read((char *)&hparams.n_heads, sizeof(hparams.n_heads));
+        finp.read((char *)&hparams.n_layers, sizeof(hparams.n_layers));
+        finp.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab));
+        finp.read((char *)&hparams.ftype, sizeof(hparams.ftype));
+
+        printf("%s: d_model = %d\n", __func__, hparams.d_model);
+        printf("%s: max_seq_len = %d\n", __func__, hparams.max_seq_len);
+        printf("%s: n_heads = %d\n", __func__, hparams.n_heads);
+        printf("%s: n_layers = %d\n", __func__, hparams.n_layers);
+        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+        printf("%s: ftype = %d\n", __func__, hparams.ftype);
+
+        fout.write((char *)&hparams.d_model, sizeof(hparams.d_model));
+        fout.write((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len));
+        fout.write((char *)&hparams.n_heads, sizeof(hparams.n_heads));
+        fout.write((char *)&hparams.n_layers, sizeof(hparams.n_layers));
+        fout.write((char *)&hparams.n_vocab, sizeof(hparams.n_vocab));
+        fout.write((char *)&ftype, sizeof(hparams.ftype));
+    }
+
+    // load vocab
+    {
+        const int32_t n_vocab = hparams.n_vocab;
+
+        std::string word;
+        for (int i = 0; i < n_vocab; i++) {
+            uint32_t len;
+            finp.read((char *)&len, sizeof(len));
+            fout.write((char *)&len, sizeof(len));
+
+            word.resize(len);
+            finp.read((char *)word.data(), len);
+            fout.write((char *)word.data(), len);
+
+            float prob;
+            finp.read((char *)&prob, sizeof(prob));
+            fout.write((char *)&prob, sizeof(prob));
+        }
+    }
+
+    printf("%s: quantizing tensors\n", __func__);
+
+    // regexes of tensor names to be quantized
+    const std::vector<std::string> to_quant = {
+        ".*weight",
+    };
+
+    if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {
+        fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__,
+                fname_inp.c_str());
+        return false;
+    }
+
+    finp.close();
+    fout.close();
+
+    return true;
+}
+
+// usage:
+//  ./replit-quantize models/replit/ggml-model.bin
+//  models/replit/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+    if (argc != 4) {
+        fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n",
+                argv[0]);
+        ggml_print_ftypes(stderr);
+        return 1;
+    }
+
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = {0, NULL, false};
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
+    const std::string fname_inp = argv[1];
+    const std::string fname_out = argv[2];
+
+    const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
+
+    const int64_t t_main_start_us = ggml_time_us();
+
+    int64_t t_quantize_us = 0;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!mpt_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
+            fprintf(stderr, "%s: failed to quantize model from '%s'\n",
+                    __func__, fname_inp.c_str());
+            return 1;
+        }
+
+        t_quantize_us = ggml_time_us() - t_start_us;
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n");
+        printf("%s: quantize time = %8.2f ms\n", __func__,
+               t_quantize_us / 1000.0f);
+        printf("%s:    total time = %8.2f ms\n", __func__,
+               (t_main_end_us - t_main_start_us) / 1000.0f);
+    }
+
+    return 0;
+}
index 2749b9aface0907b4953e89c48ca9cb986443c1f..3269218bd261b6f606c23a4fcfffae13ce64b807 100644 (file)
@@ -313,6 +313,7 @@ extern "C" {
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
+        GGML_OP_CLAMP,
         GGML_OP_CONV_1D_1S,
         GGML_OP_CONV_1D_2S,
 
@@ -897,7 +898,16 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
-            int                   n_head);
+            int                   n_head,
+            float                 bias_max);
+
+    // clamp
+    // in-place, returns view(a)
+    struct ggml_tensor * ggml_clamp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 min,
+            float                 max);
 
     // padding = 1
     // TODO: we don't support extra parameters for now
index da3d914e4ef47837fb5455676e61e95a3aef2579..5e6a9a0d462995a193006b4fad17474018717170 100644 (file)
@@ -3457,6 +3457,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
+    "CLAMP",
     "CONV_1D_1S",
     "CONV_1D_2S",
 
@@ -3467,7 +3468,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3517,6 +3519,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
+    "clamp(x)",
     "conv_1d_1s(x)",
     "conv_1d_2s(x)",
 
@@ -3527,7 +3530,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -6189,7 +6192,8 @@ struct ggml_tensor * ggml_alibi(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
-        int                   n_head) {
+        int                   n_head,
+        float                 bias_max) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
@@ -6208,6 +6212,9 @@ struct ggml_tensor * ggml_alibi(
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_head;
+    GGML_ASSERT(sizeof(float) == sizeof(int32_t));
+    (((float *) b->data)[2]) = bias_max;
+    
 
     ggml_scratch_load(ctx);
 
@@ -6219,6 +6226,36 @@ struct ggml_tensor * ggml_alibi(
     return result;
 }
 
+// ggml_alibi
+
+struct ggml_tensor * ggml_clamp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 min,
+        float                 max) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    ((float *) b->data)[0] = min;
+    ((float *) b->data)[1] = max;
+    
+
+    result->op   = GGML_OP_CLAMP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_conv_1d_1s
 
 struct ggml_tensor * ggml_conv_1d_1s(
@@ -10682,7 +10719,7 @@ static void ggml_compute_forward_alibi_f32(
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    assert(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10690,6 +10727,7 @@ static void ggml_compute_forward_alibi_f32(
 
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_head = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *) src1->data)[2];
 
     assert(n_past >= 0);
 
@@ -10712,8 +10750,8 @@ static void ggml_compute_forward_alibi_f32(
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
 
-    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
     for (int i = 0; i < ne0; i++) {
         for (int j = 0; j < ne1; j++) {
@@ -10731,7 +10769,8 @@ static void ggml_compute_forward_alibi_f32(
                     m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
                 }
 
-                pdst[0] = i * m_k + src[0];
+                pdst[0] = (i-ne0+1) * m_k + src[0];
+
             }
         }
     }
@@ -10745,7 +10784,7 @@ static void ggml_compute_forward_alibi_f16(
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    assert(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -10753,6 +10792,7 @@ static void ggml_compute_forward_alibi_f16(
 
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_head = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *) src1->data)[2];
 
     assert(n_past >= 0);
 
@@ -10775,8 +10815,8 @@ static void ggml_compute_forward_alibi_f16(
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
 
-    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
     for (int i = 0; i < ne0; i++) {
         for (int j = 0; j < ne1; j++) {
@@ -10795,7 +10835,7 @@ static void ggml_compute_forward_alibi_f16(
                 }
 
                 // we return F32
-                pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
+                pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
             }
         }
     }
@@ -10831,6 +10871,79 @@ static void ggml_compute_forward_alibi(
     }
 }
 
+
+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_clamp_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int min = ((float *) src1->data)[0];
+    const int max = ((float *) src1->data)[1];
+
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    for (int j = ith; j < n; j += nth) {
+        float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
+        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+
+            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
+        }
+    }
+}
+
+
+static void ggml_compute_forward_clamp(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_clamp_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_rope
 
 static void ggml_compute_forward_rope_f32(
@@ -12812,6 +12925,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_CLAMP:
+            {
+                ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_CONV_1D_1S:
             {
                 ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -13119,6 +13236,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CLAMP:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 // necessary for llama
@@ -13998,6 +14119,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = 1; //TODO
                     } break;
+                case GGML_OP_CLAMP:
+                    {
+                        node->n_tasks = 1; //TODO
+                    } break;
                 case GGML_OP_CONV_1D_1S:
                 case GGML_OP_CONV_1D_2S:
                     {