From: Lukas Möller Date: Wed, 17 May 2023 19:58:21 +0000 (+0200) Subject: examples : sample replit + MPT inference (#145) X-Git-Tag: upstream/0.0.1642~1465 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=1d6a133b098e3b038a7e2c544941b5288ca76bc8;p=pkg%2Fggml%2Fsources%2Fggml examples : sample replit + MPT inference (#145) * Add replit model * Add unigram tokenization support * Remove debug log * Port alibi attn bias fix * Remove torch input * Fix hardcoded path * Remove unsupported hyperparams * Add mpt * Add replit quantization script * Remove debug print * Add quantization support to mpt * Reformat * Remove trailing return type * Implement stylistic changes * use f16 in k/v memory calculations for replit/mpt * Update context size calculation * Add clip_qkv and alibi_bias_max support * fix clamping implementation, remove implicit conversions * Fix qkv if condition * Fix replit context size calculation * Potentially fix gcc compilation error * Fix warning * Adjust object overhead * Remove dead code --- diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index ceca69e9..7a4bb246 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -23,4 +23,6 @@ add_subdirectory(whisper) add_subdirectory(mnist) add_subdirectory(gpt-neox) add_subdirectory(dolly-v2) +add_subdirectory(replit) +add_subdirectory(mpt) add_subdirectory(starcoder) diff --git a/examples/mpt/CMakeLists.txt b/examples/mpt/CMakeLists.txt new file mode 100644 index 00000000..09408f9f --- /dev/null +++ b/examples/mpt/CMakeLists.txt @@ -0,0 +1,13 @@ +# +# mpt + +set(TEST_TARGET mpt) +add_executable(${TEST_TARGET} main.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) + +# +# mpt-quantize + +set(TEST_TARGET mpt-quantize) +add_executable(${TEST_TARGET} quantize.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) diff --git a/examples/mpt/convert-h5-to-ggml.py b/examples/mpt/convert-h5-to-ggml.py new file mode 100644 index 00000000..9bff9d3f --- /dev/null +++ b/examples/mpt/convert-h5-to-ggml.py @@ -0,0 +1,111 @@ +import sys +import struct +import json +import numpy as np +from transformers import AutoModelForCausalLM, AutoTokenizer +import sentencepiece.sentencepiece_model_pb2 as model + +if len(sys.argv) < 3: + print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") + sys.exit(1) + + +# output in the same directory as the model +dir_model = sys.argv[1] +fname_out = sys.argv[1] + "/ggml-model.bin" + + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if len(sys.argv) > 2: + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + sys.exit(1) + fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + + +tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained( + dir_model, low_cpu_mem_usage=True, trust_remote_code=True +) +# print (model) + +# print(tokenizer.encode('I believe the meaning of life is')) + +list_vars = model.state_dict() +for name in list_vars.keys(): + print(name, list_vars[name].shape, list_vars[name].dtype) + +fout = open(fname_out, "wb") + +print(hparams) + +fout.write(struct.pack("i", 0x67676D6C)) # magic: ggml in hex +fout.write(struct.pack("i", hparams["d_model"])) +fout.write(struct.pack("i", hparams["max_seq_len"])) +fout.write(struct.pack("i", hparams["n_heads"])) +fout.write(struct.pack("i", hparams["n_layers"])) +fout.write(struct.pack("i", hparams["vocab_size"])) +fout.write(struct.pack("f", hparams["attn_config"]["alibi_bias_max"])) +fout.write(struct.pack("f", hparams["attn_config"]["clip_qkv"] or 0.0)) +fout.write(struct.pack("i", ftype)) + + +# TODO: temporary hack to not deal with implementing the tokenizer +dot_token = tokenizer.encode(".")[0] +for i in range(hparams["vocab_size"]): + text = tokenizer.decode([dot_token, i]).encode("utf-8") + # remove the first byte (it's always '.') + text = text[1:] + fout.write(struct.pack("i", len(text))) + fout.write(text) + +for name in list_vars.keys(): + data = list_vars[name].squeeze().numpy() + print("Processing variable: " + name + " with shape: ", data.shape) + + n_dims = len(data.shape) + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0 + if ftype != 0: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + # header + str = name.encode("utf-8") + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str) + + # data + data.tofile(fout) + +fout.close() + +print("Done. Output file: " + fname_out) +print("") diff --git a/examples/mpt/main.cpp b/examples/mpt/main.cpp new file mode 100644 index 00000000..5a60367a --- /dev/null +++ b/examples/mpt/main.cpp @@ -0,0 +1,679 @@ +#include "ggml/ggml.h" + +#include "common-ggml.h" +#include "common.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int n_ctx = 4096; + +// no defaults for now +struct mpt_hparams { + int32_t d_model = 0; + int32_t max_seq_len = 0; + int32_t n_heads = 0; + int32_t n_layers = 0; + int32_t n_vocab = 0; + float alibi_bias_max = 0; + float clip_qkv = 0; + int32_t ftype = 0; +}; + +struct mpt_layer { + // pre normalization + struct ggml_tensor * norm_1_weight; + + // attention + struct ggml_tensor * c_attn_wqkv_weight; + struct ggml_tensor * c_attn_out_proj_weight; + + // post normalization + struct ggml_tensor * norm_2_weight; + + // ff + struct ggml_tensor * ffn_up_proj; + struct ggml_tensor * ffn_down_proj; +}; + +struct mpt_model { + mpt_hparams hparams; + + struct ggml_tensor * wte_weight; // position embedding + struct ggml_tensor * norm_f_weight; // language model head + + std::vector layers; + + // key + value memory + struct ggml_tensor * memory_k; + struct ggml_tensor * memory_v; + + struct ggml_context * ctx; + std::map tensors; +}; + +// load the model's weights from a file +bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) { + printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + fin.read((char *)&magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + return false; + } + } + + // load hparams + { + auto & hparams = model.hparams; + + fin.read((char *)&hparams.d_model, sizeof(hparams.d_model)); + fin.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len)); + fin.read((char *)&hparams.n_heads, sizeof(hparams.n_heads)); + fin.read((char *)&hparams.n_layers, sizeof(hparams.n_layers)); + fin.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char *)&hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max)); + fin.read((char *)&hparams.clip_qkv, sizeof(hparams.clip_qkv)); + fin.read((char *)&hparams.ftype, sizeof(hparams.ftype)); + + printf("%s: d_model = %d\n", __func__, hparams.d_model); + printf("%s: max_seq_len = %d\n", __func__, hparams.max_seq_len); + printf("%s: n_heads = %d\n", __func__, hparams.n_heads); + printf("%s: n_layers = %d\n", __func__, hparams.n_layers); + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: alibi_bias_max = %f\n", __func__, hparams.alibi_bias_max); + printf("%s: clip_qkv = %f\n", __func__, hparams.clip_qkv); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + } + + // load vocab + { + int32_t n_vocab = model.hparams.n_vocab; + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + fin.read((char *)&len, sizeof(len)); + + word.resize(len); + fin.read((char *)word.data(), len); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + // for the big tensors, we have the option to store the data in 16-bit + // floats or quantized in order to save memory and also to speed up the + // computation + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", __func__, fname.c_str(), + model.hparams.ftype); + return false; + } + + auto & ctx = model.ctx; + + size_t ctx_size = 0; + + { + const auto & hparams = model.hparams; + + const size_t n_embd = hparams.d_model; + const size_t n_layer = hparams.n_layers; + const size_t n_vocab = hparams.n_vocab; + + ctx_size += n_embd * n_vocab * ggml_type_sizef(wtype); // wte_weight + ctx_size += n_embd * ggml_type_sizef(GGML_TYPE_F32); // norm_f_weight + + ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // ln_1_weight + ctx_size += n_layer * (3 * n_embd * n_embd * ggml_type_sizef(wtype)); // attn_Wqkv_weight + ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // attn_out_proj_weight + ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // ln_2_weight + ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight + ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight + + ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k + ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v + + ctx_size += (1 + 6 * n_layer) * 512; // object overhead + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params = { + .mem_size = ctx_size, + .mem_buffer = NULL, + .no_alloc = false, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + const auto & hparams = model.hparams; + + const size_t n_embd = hparams.d_model; + const size_t n_layer = hparams.n_layers; + const size_t n_vocab = hparams.n_vocab; + + model.layers.resize(n_layer); + + model.wte_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.norm_f_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + // map by name + model.tensors["transformer.wte.weight"] = model.wte_weight; + model.tensors["transformer.norm_f.weight"] = model.norm_f_weight; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.norm_1_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.c_attn_wqkv_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, 3 * n_embd); + layer.c_attn_out_proj_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.norm_2_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.ffn_up_proj = ggml_new_tensor_2d(ctx, wtype, n_embd, 4 * n_embd); + layer.ffn_down_proj = ggml_new_tensor_2d(ctx, wtype, 4 * n_embd, n_embd); + + // map by name + model.tensors["transformer.blocks." + std::to_string(i) + ".norm_1.weight"] = layer.norm_1_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".attn.Wqkv.weight"] = layer.c_attn_wqkv_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".attn.out_proj.weight"] = + layer.c_attn_out_proj_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".norm_2.weight"] = layer.norm_2_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".ffn.up_proj.weight"] = layer.ffn_up_proj; + model.tensors["transformer.blocks." + std::to_string(i) + ".ffn.down_proj.weight"] = layer.ffn_down_proj; + } + } + + // key + value memory + { + const auto & hparams = model.hparams; + + const size_t n_embd = hparams.d_model; + const size_t n_layer = hparams.n_layers; + + const int64_t n_mem = n_layer * n_ctx; + const int64_t n_elements = n_embd * n_mem; + + model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + + printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem); + } + + // load weights + { + int n_tensors = 0; + size_t total_size = 0; + + printf("%s: ", __func__); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = {1, 1}; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, + "%s: tensor '%s' has wrong shape in model file: got [%5d, " + "%5d], expected [%5d, %5d]\n", + __func__, name.data(), (int)tensor->ne[0], (int)tensor->ne[1], ne[0], ne[1]); + return false; + } + + // for debugging + if (0) { + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], + ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor) / 1024.0 / 1024.0, ggml_nbytes(tensor)); + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, + "%s: tensor '%s' has wrong size in model file: got %zu, " + "expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements * bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + + total_size += ggml_nbytes(tensor); + if (++n_tensors % 8 == 0) { + printf("."); + fflush(stdout); + } + } + + printf(" done\n"); + + printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size / 1024.0 / 1024.0, n_tensors); + } + + fin.close(); + + return true; +} + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, + const std::vector & embd_inp, std::vector & embd_w, size_t & mem_per_token) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.d_model; + const int n_layer = hparams.n_layers; + const int n_head = hparams.n_heads; + const int n_vocab = hparams.n_vocab; + + static size_t buf_size = 256u * 1024 * 1024; + static void * buf = malloc(buf_size); + + if (mem_per_token > 0 && mem_per_token * N > buf_size) { + const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead + // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, + // buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ggml_init_params params = { + .mem_size = buf_size, + .mem_buffer = buf, + .no_alloc = false, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = {.n_threads = n_threads}; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd)); + + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte_weight, embd); + + for (int il = 0; il < n_layer; ++il) { + + struct ggml_tensor * cur; + + // a = self.ln_1(x) + { + cur = ggml_norm(ctx0, inpL); + + cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur); + } + + // self-attention + // b, _, past_key_value = self.attn(a, past_key_value=past_key_value, + // attn_bias=attn_bias, attention_mask=attention_mask, + // is_causal=is_causal) + { + + // compute QKV + cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_wqkv_weight, cur); + + if (model.hparams.clip_qkv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -model.hparams.clip_qkv, model.hparams.clip_qkv); + } + + struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd); + struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd); + struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd); + + // store key and value to memory + { + struct ggml_tensor * k = + ggml_view_1d(ctx0, model.memory_k, N * n_embd, + (ggml_element_size(model.memory_k) * n_embd) * (il * n_ctx + n_past)); + struct ggml_tensor * v = + ggml_view_1d(ctx0, model.memory_v, N * n_embd, + (ggml_element_size(model.memory_v) * n_embd) * (il * n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, + // 2, 1, 3) [64, N, 12] + struct ggml_tensor * Q = ggml_permute( + ctx0, ggml_cpy(ctx0, Qcur, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd / n_head, n_head, N)), 0, 2, + 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, + // 3) [64, n_past + N, 12] + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_embd, + il * n_ctx * ggml_element_size(model.memory_k) * n_embd), + n_embd / n_head, n_head, n_past + N), + 0, 2, 1, 3); + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head))); + + struct ggml_tensor * KQ_scaled_alibi = + ggml_alibi(ctx0, KQ_scaled, n_past, n_head, model.hparams.alibi_bias_max); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, + // 2, 0, 3).contiguous() [n_past + N, 64, 12] + struct ggml_tensor * V_trans = ggml_cpy( + ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_embd, + il * n_ctx * ggml_element_size(model.memory_v) * n_embd), + n_embd / n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd / n_head, n_head)); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection + { cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_out_proj_weight, cur); } + } + + inpL = ggml_add(ctx0, inpL, cur); + + // m = self.ln_2(x) + { + cur = ggml_norm(ctx0, inpL); + + cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur); + } + + // n = self.mlp(m) + { + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_proj, cur); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + // cur = proj_w*cur + proj_b + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_proj, cur); + } + + // x = x + n + inpL = ggml_add(ctx0, inpL, cur); + } + + // norm + { + inpL = ggml_norm(ctx0, inpL); + // inpL = ln_f_g*inpL + inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL); + } + + // output embedding weight tied to input embedding + inpL = ggml_mul_mat(ctx0, model.wte_weight, inpL); + + // logits -> probs + // inpL = ggml_soft_max(ctx0, inpL); + + // run the computation + ggml_build_forward_expand(&gf, inpL); + ggml_graph_compute(ctx0, &gf); + + // std::cout << "Qcur" << std::endl; + // print_tensor(Qcur); + + // if (n_past%100 == 0) { + // ggml_graph_print(&gf); + // ggml_graph_dump_dot(&gf, NULL, "mpt-model.dot"); + // } + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0) / N; + } + // printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + + ggml_free(ctx0); + + return true; +} + +int main(int argc, char ** argv) { + const int64_t t_main_start_us = ggml_time_us(); + + gpt_params params; + params.model = ""; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.prompt.empty()) { + if (!isatty(STDIN_FILENO)) { + std::string line; + while (std::getline(std::cin, line)) { + params.prompt = params.prompt + "\n" + line; + } + } else { + params.prompt = gpt_random_prompt(rng); + } + } + + int64_t t_load_us = 0; + + gpt_vocab vocab; + mpt_model model; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!mpt_model_load(params.model, model, vocab)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + } + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + + for (int i = 0; i < embd_inp.size(); i++) { + printf("%s: token[%d] = %6d\n", __func__, i, embd_inp[i]); + // vocab.id_to_token.at(embd_inp[i]).c_str() + } + printf("\n"); + + params.n_predict = std::min(params.n_predict, n_ctx - (int)embd_inp.size()); + + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token); + + for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ggml_time_us(); + + if (!mpt_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ggml_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + + const int n_vocab = model.hparams.n_vocab; + + gpt_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_time_us(); + + id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (int k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (embd.size() > params.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // display text + for (auto id : embd) { + printf("%s", vocab.id_to_token[id].c_str()); + } + fflush(stdout); + + // end of text token + if (embd.back() == 0) { + break; + } + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n\n"); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f); + printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f); + printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, + t_predict_us / 1000.0f / n_past); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); + } + + ggml_free(model.ctx); + + return 0; +} diff --git a/examples/mpt/quantize.cpp b/examples/mpt/quantize.cpp new file mode 100644 index 00000000..8f32bdd4 --- /dev/null +++ b/examples/mpt/quantize.cpp @@ -0,0 +1,180 @@ +#include "ggml/ggml.h" + +#include "common-ggml.h" +#include "common.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct mpt_hparams { + int32_t d_model = 0; + int32_t max_seq_len = 0; + int32_t n_heads = 0; + int32_t n_layers = 0; + int32_t n_vocab = 0; + float alibi_bias_max = 0; + float clip_qkv = 0; + int32_t ftype = 0; +}; + +// quantize a model +bool mpt_model_quantize(const std::string & fname_inp, + const std::string & fname_out, ggml_ftype ftype) { + + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, + fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, + fname_out.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + finp.read((char *)&magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", + __func__, fname_inp.c_str()); + return false; + } + + fout.write((char *)&magic, sizeof(magic)); + } + + mpt_hparams hparams; + + // load hparams + { + finp.read((char *)&hparams.d_model, sizeof(hparams.d_model)); + finp.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len)); + finp.read((char *)&hparams.n_heads, sizeof(hparams.n_heads)); + finp.read((char *)&hparams.n_layers, sizeof(hparams.n_layers)); + finp.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *)&hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max)); + finp.read((char *)&hparams.clip_qkv, sizeof(hparams.clip_qkv)); + finp.read((char *)&hparams.ftype, sizeof(hparams.ftype)); + + printf("%s: d_model = %d\n", __func__, hparams.d_model); + printf("%s: max_seq_len = %d\n", __func__, hparams.max_seq_len); + printf("%s: n_heads = %d\n", __func__, hparams.n_heads); + printf("%s: n_layers = %d\n", __func__, hparams.n_layers); + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: alibi_bias_max = %f\n", __func__, hparams.alibi_bias_max); + printf("%s: clip_qkv = %f\n", __func__, hparams.clip_qkv); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + + fout.write((char *)&hparams.d_model, sizeof(hparams.d_model)); + fout.write((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len)); + fout.write((char *)&hparams.n_heads, sizeof(hparams.n_heads)); + fout.write((char *)&hparams.n_layers, sizeof(hparams.n_layers)); + fout.write((char *)&hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *)&hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max)); + fout.write((char *)&hparams.clip_qkv, sizeof(hparams.clip_qkv)); + fout.write((char *)&ftype, sizeof(hparams.ftype)); + } + + // load vocab + { + const int32_t n_vocab = hparams.n_vocab; + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + finp.read((char *)&len, sizeof(len)); + fout.write((char *)&len, sizeof(len)); + + word.resize(len); + finp.read((char *)word.data(), len); + fout.write((char *)word.data(), len); + } + } + + printf("%s: quantizing tensors\n", __func__); + + // regexes of tensor names to be quantized + const std::vector to_quant = { + ".*weight", + }; + + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) { + fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, + fname_inp.c_str()); + return false; + } + + finp.close(); + fout.close(); + + return true; +} + +// usage: +// ./mpt-quantize models/mpt/ggml-model.bin +// models/mpt/ggml-model-quant.bin type +// +int main(int argc, char ** argv) { + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", + argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + // needed to initialize f16 tables + { + struct ggml_init_params params = {0, NULL, false}; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_quantize_us = 0; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!mpt_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", + __func__, fname_inp.c_str()); + return 1; + } + + t_quantize_us = ggml_time_us() - t_start_us; + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n"); + printf("%s: quantize time = %8.2f ms\n", __func__, + t_quantize_us / 1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, + (t_main_end_us - t_main_start_us) / 1000.0f); + } + + return 0; +} diff --git a/examples/replit/CMakeLists.txt b/examples/replit/CMakeLists.txt new file mode 100644 index 00000000..696b7f98 --- /dev/null +++ b/examples/replit/CMakeLists.txt @@ -0,0 +1,13 @@ +# +# replit + +set(TEST_TARGET replit) +add_executable(${TEST_TARGET} main.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) + +# +# replit-quantize + +set(TEST_TARGET replit-quantize) +add_executable(${TEST_TARGET} quantize.cpp) +target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml) diff --git a/examples/replit/convert-h5-to-ggml.py b/examples/replit/convert-h5-to-ggml.py new file mode 100644 index 00000000..310074b1 --- /dev/null +++ b/examples/replit/convert-h5-to-ggml.py @@ -0,0 +1,113 @@ +from pathlib import Path +import sys +import struct +import json +import numpy as np +from transformers import AutoModelForCausalLM, AutoTokenizer +import sentencepiece.sentencepiece_model_pb2 as model + +if len(sys.argv) < 3: + print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") + sys.exit(1) + + +# output in the same directory as the model +dir_model = sys.argv[1] +fname_out = sys.argv[1] + "/ggml-model.bin" + + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +sp_proto = model.ModelProto() +sp_proto.ParseFromString(open(Path(sys.argv[1]) / "spiece.model", "rb").read()) + + +# possible data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 +# +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if len(sys.argv) > 2: + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + sys.exit(1) + fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + + +tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained( + dir_model, low_cpu_mem_usage=True, trust_remote_code=True +) +# print (model) + +# print(tokenizer.encode('I believe the meaning of life is')) + +list_vars = model.state_dict() +for name in list_vars.keys(): + print(name, list_vars[name].shape, list_vars[name].dtype) + +fout = open(fname_out, "wb") + +print(hparams) + +fout.write(struct.pack("i", 0x67676D6C)) # magic: ggml in hex +fout.write(struct.pack("i", hparams["d_model"])) +fout.write(struct.pack("i", hparams["max_seq_len"])) +fout.write(struct.pack("i", hparams["n_heads"])) +fout.write(struct.pack("i", hparams["n_layers"])) +fout.write(struct.pack("i", hparams["vocab_size"])) +fout.write(struct.pack("i", ftype)) + + +# TODO: temporary hack to not deal with implementing the tokenizer +for piece in sp_proto.pieces: + encoded_piece = piece.piece.encode("utf-8") + fout.write(struct.pack("i", len(encoded_piece))) + fout.write(encoded_piece) + fout.write(struct.pack("f", piece.score)) + + +for name in list_vars.keys(): + data = list_vars[name].squeeze().numpy() + print("Processing variable: " + name + " with shape: ", data.shape) + + n_dims = len(data.shape) + + # ftype == 0 -> float32, ftype == 1 -> float16 + ftype_cur = 0 + if ftype != 0: + if name[-7:] == ".weight" and n_dims == 2: + print(" Converting to float16") + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + # header + str = name.encode("utf-8") + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str) + + # data + data.tofile(fout) + +fout.close() + +print("Done. Output file: " + fname_out) +print("") diff --git a/examples/replit/main.cpp b/examples/replit/main.cpp new file mode 100644 index 00000000..e01db755 --- /dev/null +++ b/examples/replit/main.cpp @@ -0,0 +1,767 @@ +#include "ggml/ggml.h" + +#include "common-ggml.h" +#include "common.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using piece_t = std::pair; +using piece_map_t = std::unordered_map; + +struct replit_tokenizer { + gpt_vocab raw_vocab; + piece_map_t piece_map; + std::vector vocab; +}; + +std::pair, float> encode_word(const std::string & word, const piece_map_t & model) { + std::vector best_segmentations_starts(word.length() + 1, -1); + best_segmentations_starts[0] = 0; + + std::vector best_segmentations_scores(word.length() + 1, -std::numeric_limits::infinity()); + best_segmentations_scores[0] = 1.0; + + for (int start_idx = 0; start_idx < word.length(); ++start_idx) { + float best_score_at_start = best_segmentations_scores[start_idx]; + for (int end_idx = start_idx + 1; end_idx <= word.length(); ++end_idx) { + std::string token = word.substr(start_idx, end_idx - start_idx); + if (model.count(token) && best_score_at_start != -std::numeric_limits::infinity()) { + float token_score = model.at(token).second; + float score = token_score + best_score_at_start; + if (best_segmentations_scores[end_idx] == -std::numeric_limits::infinity() || + best_segmentations_scores[end_idx] > score) { + best_segmentations_starts[end_idx] = start_idx; + best_segmentations_scores[end_idx] = score; + } + } + } + } + + if (best_segmentations_scores.back() == -std::numeric_limits::infinity()) { + return std::make_pair(std::vector{0}, 0.0f); + } + + float score = best_segmentations_scores.back(); + int start = best_segmentations_starts.back(); + int end = word.length(); + std::vector tokens; + while (start != 0) { + const auto token_id = model.at(word.substr(start, end - start)).first; + tokens.insert(tokens.begin(), token_id); + int next_start = best_segmentations_starts[start]; + end = start; + start = next_start; + } + const auto token_id = model.at(word.substr(start, end - start)).first; + tokens.insert(tokens.begin(), token_id); + return std::make_pair(tokens, score); +} + +bool replit_tokenizer_load(replit_tokenizer & tokenizer, std::istream & fin, int max_vocab_size) { + + for (std::size_t i = 0; i < max_vocab_size; i++) { + + uint32_t len; + fin.read((char *)&len, sizeof(len)); + + std::string word; + word.resize(len); + fin.read((char *)word.data(), len); + + float score; + fin.read((char *)&score, sizeof(score)); + + tokenizer.piece_map[word] = std::make_pair(i, -score); + tokenizer.raw_vocab.id_to_token[i] = word; + } + + return true; +} + +std::string replace_all(const std::string & str, // where to work + const std::string & find, // substitute 'find' + const std::string & replace // by 'replace' +) { + using namespace std; + string result; + size_t find_len = find.size(); + size_t pos, from = 0; + while (string::npos != (pos = str.find(find, from))) { + result.append(str, from, pos - from); + result.append(replace); + from = pos + find_len; + } + result.append(str, from, string::npos); + return result; +} + +std::string ws_symbol = "\342\226\201"; +std::vector replit_tokenizer_tokenize(replit_tokenizer & tokenizer, const std::string & text) { + std::vector tokens; + auto normalized_text = replace_all(text, " ", ws_symbol); + auto tokenized = encode_word(normalized_text, tokenizer.piece_map); + + return tokenized.first; +} + +std::string replit_tokenizer_detokenize(replit_tokenizer & tokenizer, const std::vector & tokens) { + std::string text; + for (auto token : tokens) { + text += tokenizer.raw_vocab.id_to_token[token]; + } + auto denormalized_text = replace_all(text, ws_symbol, " "); + return denormalized_text; +} + +// no defaults for now +struct mpt_hparams { + int32_t d_model = 0; + int32_t max_seq_len = 0; + int32_t n_heads = 0; + int32_t n_layers = 0; + int32_t n_vocab = 0; + int32_t ftype = 0; +}; + +struct replit_layer { + // pre normalization + struct ggml_tensor * ln_1_weight; + + // attention + struct ggml_tensor * c_attn_wqkv_weight; + + struct ggml_tensor * c_attn_out_proj_weight; + + // post normalization + struct ggml_tensor * ln_2_weight; + + // ff + struct ggml_tensor * c_mlp_mlp_up_weight; + + struct ggml_tensor * c_mlp_mlp_down_weight; +}; + +struct replit_model { + mpt_hparams hparams; + + struct ggml_tensor * wte_weight; // position embedding + struct ggml_tensor * ln_f_weight; // language model head + + std::vector layers; + + // key + value memory + struct ggml_tensor * memory_k; + struct ggml_tensor * memory_v; + + struct ggml_context * ctx; + std::map tensors; +}; + +// load the model's weights from a file +bool replit_model_load(const std::string & fname, replit_model & model, replit_tokenizer & vocab) { + printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); + + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + fin.read((char *)&magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str()); + return false; + } + } + + // load hparams + { + auto & hparams = model.hparams; + + fin.read((char *)&hparams.d_model, sizeof(hparams.d_model)); + fin.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len)); + fin.read((char *)&hparams.n_heads, sizeof(hparams.n_heads)); + fin.read((char *)&hparams.n_layers, sizeof(hparams.n_layers)); + fin.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab)); + fin.read((char *)&hparams.ftype, sizeof(hparams.ftype)); + + printf("%s: d_model = %d\n", __func__, hparams.d_model); + printf("%s: max_seq_len = %d\n", __func__, hparams.max_seq_len); + printf("%s: n_heads = %d\n", __func__, hparams.n_heads); + printf("%s: n_layers = %d\n", __func__, hparams.n_layers); + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + } + + // load vocab + replit_tokenizer_load(vocab, fin, model.hparams.n_vocab); + + // for the big tensors, we have the option to store the data in 16-bit + // floats or quantized in order to save memory and also to speed up the + // computation + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", __func__, fname.c_str(), + model.hparams.ftype); + return false; + } + + auto & ctx = model.ctx; + + size_t ctx_size = 0; + + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.d_model; + const int n_layer = hparams.n_layers; + const int n_ctx = hparams.max_seq_len; + const int n_vocab = hparams.n_vocab; + + ctx_size += n_embd * n_vocab * ggml_type_sizef(wtype); // wte_weight + ctx_size += n_embd * ggml_type_sizef(GGML_TYPE_F32); // ln_f_weight + + ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // ln_1_weight + ctx_size += n_layer * (3 * n_embd * n_embd * ggml_type_sizef(wtype)); // attn_Wqkv_weight + ctx_size += n_layer * (n_embd * n_embd * ggml_type_sizef(wtype)); // attn_out_proj_weight + ctx_size += n_layer * (n_embd * ggml_type_sizef(GGML_TYPE_F32)); // ln_2_weight + ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight + ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight + + ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k + ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v + + ctx_size += (1 + 6 * n_layer) * 512; // object overhead + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0)); + } + + // create the ggml context + { + struct ggml_init_params params = { + .mem_size = ctx_size, + .mem_buffer = NULL, + .no_alloc = false, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.d_model; + const int n_layer = hparams.n_layers; + const int n_vocab = hparams.n_vocab; + + model.layers.resize(n_layer); + + model.wte_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab); + model.ln_f_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + // map by name + model.tensors["transformer.wte.weight"] = model.wte_weight; + model.tensors["transformer.ln_f.weight"] = model.ln_f_weight; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.ln_1_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.c_attn_wqkv_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, 3 * n_embd); + layer.c_attn_out_proj_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd); + layer.ln_2_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + layer.c_mlp_mlp_up_weight = ggml_new_tensor_2d(ctx, wtype, n_embd, 4 * n_embd); + layer.c_mlp_mlp_down_weight = ggml_new_tensor_2d(ctx, wtype, 4 * n_embd, n_embd); + + // map by name + model.tensors["transformer.blocks." + std::to_string(i) + ".ln_1.weight"] = layer.ln_1_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".attn.Wqkv.weight"] = layer.c_attn_wqkv_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".attn.out_proj.weight"] = + layer.c_attn_out_proj_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".ln_2.weight"] = layer.ln_2_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".mlp.mlp_up.weight"] = layer.c_mlp_mlp_up_weight; + model.tensors["transformer.blocks." + std::to_string(i) + ".mlp.mlp_down.weight"] = + layer.c_mlp_mlp_down_weight; + } + } + + // key + value memory + { + const auto & hparams = model.hparams; + + const int n_embd = hparams.d_model; + const int n_layer = hparams.n_layers; + const int n_ctx = hparams.max_seq_len; + + const int64_t n_mem = n_layer * n_ctx; + const int64_t n_elements = n_embd * n_mem; + + model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + + printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem); + } + + // load weights + { + int n_tensors = 0; + size_t total_size = 0; + + printf("%s: ", __func__); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ttype), sizeof(ttype)); + + if (fin.eof()) { + break; + } + + int32_t nelements = 1; + int32_t ne[2] = {1, 1}; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) { + fprintf(stderr, + "%s: tensor '%s' has wrong shape in model file: got [%5d, " + "%5d], expected [%5d, %5d]\n", + __func__, name.data(), (int)tensor->ne[0], (int)tensor->ne[1], ne[0], ne[1]); + return false; + } + + // for debugging + if (0) { + printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], + ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor) / 1024.0 / 1024.0, ggml_nbytes(tensor)); + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, + "%s: tensor '%s' has wrong size in model file: got %zu, " + "expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements * bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + + total_size += ggml_nbytes(tensor); + if (++n_tensors % 8 == 0) { + printf("."); + fflush(stdout); + } + } + + printf(" done\n"); + + printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size / 1024.0 / 1024.0, n_tensors); + } + + fin.close(); + + return true; +} + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +bool replit_eval(const replit_model & model, const int n_threads, const int n_past, + const std::vector & embd_inp, std::vector & embd_w, size_t & mem_per_token) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.d_model; + const int n_layer = hparams.n_layers; + const int n_ctx = hparams.max_seq_len; + const int n_head = hparams.n_heads; + const int n_vocab = hparams.n_vocab; + + static size_t buf_size = 256u * 1024 * 1024; + static void * buf = malloc(buf_size); + + if (mem_per_token > 0 && mem_per_token * N > buf_size) { + const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead + // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, + // buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ggml_init_params params = { + .mem_size = buf_size, + .mem_buffer = buf, + .no_alloc = false, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = {.n_threads = n_threads}; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd)); + + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte_weight, embd); + + for (int il = 0; il < n_layer; ++il) { + + struct ggml_tensor * cur; + + // a = self.ln_1(x) + { + cur = ggml_norm(ctx0, inpL); + + cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_weight, cur), cur); + } + + // self-attention + // b, _, past_key_value = self.attn(a, past_key_value=past_key_value, + // attn_bias=attn_bias, attention_mask=attention_mask, + // is_causal=is_causal) + { + + // compute QKV + { cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_wqkv_weight, cur); } + + struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd); + struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd); + struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd); + + // store key and value to memory + { + struct ggml_tensor * k = + ggml_view_1d(ctx0, model.memory_k, N * n_embd, + (ggml_element_size(model.memory_k) * n_embd) * (il * n_ctx + n_past)); + struct ggml_tensor * v = + ggml_view_1d(ctx0, model.memory_v, N * n_embd, + (ggml_element_size(model.memory_v) * n_embd) * (il * n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, + // 2, 1, 3) [64, N, 12] + struct ggml_tensor * Q = ggml_permute( + ctx0, ggml_cpy(ctx0, Qcur, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd / n_head, n_head, N)), 0, 2, + 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, + // 3) [64, n_past + N, 12] + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_embd, + il * n_ctx * ggml_element_size(model.memory_k) * n_embd), + n_embd / n_head, n_head, n_past + N), + 0, 2, 1, 3); + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head))); + + struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8.0); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, + // 2, 0, 3).contiguous() [n_past + N, 64, 12] + struct ggml_tensor * V_trans = ggml_cpy( + ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_embd, + il * n_ctx * ggml_element_size(model.memory_v) * n_embd), + n_embd / n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd / n_head, n_head)); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection + { cur = ggml_mul_mat(ctx0, model.layers[il].c_attn_out_proj_weight, cur); } + } + + inpL = ggml_add(ctx0, inpL, cur); + + // m = self.ln_2(x) + { + cur = ggml_norm(ctx0, inpL); + + cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_weight, cur), cur); + } + + // n = self.mlp(m) + { + + cur = ggml_mul_mat(ctx0, model.layers[il].c_mlp_mlp_up_weight, cur); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + // cur = proj_w*cur + proj_b + cur = ggml_mul_mat(ctx0, model.layers[il].c_mlp_mlp_down_weight, cur); + } + + // x = x + n + inpL = ggml_add(ctx0, inpL, cur); + } + + // norm + { + inpL = ggml_norm(ctx0, inpL); + // inpL = ln_f_g*inpL + inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.ln_f_weight, inpL), inpL); + } + + // output embedding weight tied to input embedding + inpL = ggml_mul_mat(ctx0, model.wte_weight, inpL); + + // logits -> probs + // inpL = ggml_soft_max(ctx0, inpL); + + // run the computation + ggml_build_forward_expand(&gf, inpL); + ggml_graph_compute(ctx0, &gf); + + // std::cout << "Qcur" << std::endl; + // print_tensor(Qcur); + + // if (n_past%100 == 0) { + // ggml_graph_print(&gf); + // ggml_graph_dump_dot(&gf, NULL, "replit-model.dot"); + // } + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0) / N; + } + // printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + + ggml_free(ctx0); + + return true; +} + +int main(int argc, char ** argv) { + const int64_t t_main_start_us = ggml_time_us(); + + gpt_params params; + params.model = ""; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.prompt.empty()) { + if (!isatty(STDIN_FILENO)) { + std::string line; + while (std::getline(std::cin, line)) { + params.prompt = params.prompt + "\n" + line; + } + } else { + params.prompt = gpt_random_prompt(rng); + } + } + + int64_t t_load_us = 0; + + replit_tokenizer vocab; + replit_model model; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!replit_model_load(params.model, model, vocab)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + } + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = replit_tokenizer_tokenize(vocab, params.prompt); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + + for (int i = 0; i < embd_inp.size(); i++) { + printf("%s: token[%d] = %6lu\n", __func__, i, embd_inp[i]); + // vocab.id_to_token.at(embd_inp[i]).c_str() + } + printf("\n"); + + params.n_predict = std::min(params.n_predict, model.hparams.max_seq_len - (int)embd_inp.size()); + + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + replit_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token); + + for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ggml_time_us(); + + if (!replit_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ggml_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + + const int n_vocab = model.hparams.n_vocab; + + gpt_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_time_us(); + + id = gpt_sample_top_k_top_p(vocab.raw_vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, + temp, rng); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (int k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (embd.size() > params.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // display text + for (auto id : embd) { + printf("%s", replit_tokenizer_detokenize(vocab, {static_cast(id)}).c_str()); + } + fflush(stdout); + + // end of text token + if (embd.back() == 0) { + break; + } + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n\n"); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f); + printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f); + printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, + t_predict_us / 1000.0f / n_past); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); + } + + ggml_free(model.ctx); + + return 0; +} diff --git a/examples/replit/quantize.cpp b/examples/replit/quantize.cpp new file mode 100644 index 00000000..40a58060 --- /dev/null +++ b/examples/replit/quantize.cpp @@ -0,0 +1,176 @@ +#include "ggml/ggml.h" + +#include "common-ggml.h" +#include "common.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct mpt_hparams { + int32_t d_model = 0; + int32_t max_seq_len = 0; + int32_t n_heads = 0; + int32_t n_layers = 0; + int32_t n_vocab = 0; + int32_t ftype = 0; +}; + +// quantize a model +bool mpt_model_quantize(const std::string & fname_inp, + const std::string & fname_out, ggml_ftype ftype) { + + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, + fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, + fname_out.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + finp.read((char *)&magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", + __func__, fname_inp.c_str()); + return false; + } + + fout.write((char *)&magic, sizeof(magic)); + } + + mpt_hparams hparams; + + // load hparams + { + finp.read((char *)&hparams.d_model, sizeof(hparams.d_model)); + finp.read((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len)); + finp.read((char *)&hparams.n_heads, sizeof(hparams.n_heads)); + finp.read((char *)&hparams.n_layers, sizeof(hparams.n_layers)); + finp.read((char *)&hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *)&hparams.ftype, sizeof(hparams.ftype)); + + printf("%s: d_model = %d\n", __func__, hparams.d_model); + printf("%s: max_seq_len = %d\n", __func__, hparams.max_seq_len); + printf("%s: n_heads = %d\n", __func__, hparams.n_heads); + printf("%s: n_layers = %d\n", __func__, hparams.n_layers); + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + + fout.write((char *)&hparams.d_model, sizeof(hparams.d_model)); + fout.write((char *)&hparams.max_seq_len, sizeof(hparams.max_seq_len)); + fout.write((char *)&hparams.n_heads, sizeof(hparams.n_heads)); + fout.write((char *)&hparams.n_layers, sizeof(hparams.n_layers)); + fout.write((char *)&hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *)&ftype, sizeof(hparams.ftype)); + } + + // load vocab + { + const int32_t n_vocab = hparams.n_vocab; + + std::string word; + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + finp.read((char *)&len, sizeof(len)); + fout.write((char *)&len, sizeof(len)); + + word.resize(len); + finp.read((char *)word.data(), len); + fout.write((char *)word.data(), len); + + float prob; + finp.read((char *)&prob, sizeof(prob)); + fout.write((char *)&prob, sizeof(prob)); + } + } + + printf("%s: quantizing tensors\n", __func__); + + // regexes of tensor names to be quantized + const std::vector to_quant = { + ".*weight", + }; + + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) { + fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, + fname_inp.c_str()); + return false; + } + + finp.close(); + fout.close(); + + return true; +} + +// usage: +// ./replit-quantize models/replit/ggml-model.bin +// models/replit/ggml-model-quant.bin type +// +int main(int argc, char ** argv) { + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", + argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + // needed to initialize f16 tables + { + struct ggml_init_params params = {0, NULL, false}; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); + + const int64_t t_main_start_us = ggml_time_us(); + + int64_t t_quantize_us = 0; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!mpt_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", + __func__, fname_inp.c_str()); + return 1; + } + + t_quantize_us = ggml_time_us() - t_start_us; + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n"); + printf("%s: quantize time = %8.2f ms\n", __func__, + t_quantize_us / 1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, + (t_main_end_us - t_main_start_us) / 1000.0f); + } + + return 0; +} diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 2749b9af..3269218b 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -313,6 +313,7 @@ extern "C" { GGML_OP_ROPE, GGML_OP_ROPE_BACK, GGML_OP_ALIBI, + GGML_OP_CLAMP, GGML_OP_CONV_1D_1S, GGML_OP_CONV_1D_2S, @@ -897,7 +898,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, int n_past, - int n_head); + int n_head, + float bias_max); + + // clamp + // in-place, returns view(a) + struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max); // padding = 1 // TODO: we don't support extra parameters for now diff --git a/src/ggml.c b/src/ggml.c index da3d914e..5e6a9a0d 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3457,6 +3457,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "ROPE", "ROPE_BACK", "ALIBI", + "CLAMP", "CONV_1D_1S", "CONV_1D_2S", @@ -3467,7 +3468,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "MAP_BINARY", }; -static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50"); +static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); + static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3517,6 +3519,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rope(x)", "rope_back(x)", "alibi(x)", + "clamp(x)", "conv_1d_1s(x)", "conv_1d_2s(x)", @@ -3527,7 +3530,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "f(x,y)", }; -static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50"); +static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -6189,7 +6192,8 @@ struct ggml_tensor * ggml_alibi( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, - int n_head) { + int n_head, + float bias_max) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6208,6 +6212,9 @@ struct ggml_tensor * ggml_alibi( ((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[1] = n_head; + GGML_ASSERT(sizeof(float) == sizeof(int32_t)); + (((float *) b->data)[2]) = bias_max; + ggml_scratch_load(ctx); @@ -6219,6 +6226,36 @@ struct ggml_tensor * ggml_alibi( return result; } +// ggml_alibi + +struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + ((float *) b->data)[0] = min; + ((float *) b->data)[1] = max; + + + result->op = GGML_OP_CLAMP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + // ggml_conv_1d_1s struct ggml_tensor * ggml_conv_1d_1s( @@ -10682,7 +10719,7 @@ static void ggml_compute_forward_alibi_f32( struct ggml_tensor * dst) { assert(params->ith == 0); assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -10690,6 +10727,7 @@ static void ggml_compute_forward_alibi_f32( const int n_past = ((int32_t *) src1->data)[0]; const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; assert(n_past >= 0); @@ -10712,8 +10750,8 @@ static void ggml_compute_forward_alibi_f32( // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor); - const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); for (int i = 0; i < ne0; i++) { for (int j = 0; j < ne1; j++) { @@ -10731,7 +10769,8 @@ static void ggml_compute_forward_alibi_f32( m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); } - pdst[0] = i * m_k + src[0]; + pdst[0] = (i-ne0+1) * m_k + src[0]; + } } } @@ -10745,7 +10784,7 @@ static void ggml_compute_forward_alibi_f16( struct ggml_tensor * dst) { assert(params->ith == 0); assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -10753,6 +10792,7 @@ static void ggml_compute_forward_alibi_f16( const int n_past = ((int32_t *) src1->data)[0]; const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; assert(n_past >= 0); @@ -10775,8 +10815,8 @@ static void ggml_compute_forward_alibi_f16( // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor); - const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); for (int i = 0; i < ne0; i++) { for (int j = 0; j < ne1; j++) { @@ -10795,7 +10835,7 @@ static void ggml_compute_forward_alibi_f16( } // we return F32 - pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); + pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]); } } } @@ -10831,6 +10871,79 @@ static void ggml_compute_forward_alibi( } } + +// ggml_compute_forward_alibi + +static void ggml_compute_forward_clamp_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 2); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int min = ((float *) src1->data)[0]; + const int max = ((float *) src1->data)[1]; + + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + + +static void ggml_compute_forward_clamp( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_clamp_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_rope static void ggml_compute_forward_rope_f32( @@ -12812,6 +12925,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_CLAMP: + { + ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor); + } break; case GGML_OP_CONV_1D_1S: { ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); @@ -13119,6 +13236,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CLAMP: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SILU: { // necessary for llama @@ -13998,6 +14119,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; //TODO } break; + case GGML_OP_CLAMP: + { + node->n_tasks = 1; //TODO + } break; case GGML_OP_CONV_1D_1S: case GGML_OP_CONV_1D_2S: {