- [X] Example of RWKV inference [saharNooby/rwkv.cpp](https://github.com/saharNooby/rwkv.cpp)
- [ ] Example of [SAM](https://github.com/facebookresearch/segment-anything) inference
- [ ] Idea for GPU support: https://github.com/ggerganov/llama.cpp/discussions/915
-- [X] Example of StableLM (GPT-NeoX) inference [examples/stablelm](https://github.com/ggerganov/ggml/tree/master/examples/stablelm)
+- [X] Example of StableLM (GPT-NeoX) inference [examples/gpt-neox](https://github.com/ggerganov/ggml/tree/master/examples/gpt-neox)
- [X] Example of BERT inference [skeskinen/bert.cpp](https://github.com/skeskinen/bert.cpp)
## Whisper inference (example)
add_subdirectory(gpt-j)
add_subdirectory(whisper)
add_subdirectory(mnist)
-add_subdirectory(stablelm)
+add_subdirectory(gpt-neox)
add_subdirectory(dolly-v2)
--- /dev/null
+#
+# gpt-neox
+
+set(TEST_TARGET gpt-neox)
+add_executable(${TEST_TARGET} main.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+
+#
+# gpt-neox-quantize
+
+set(TEST_TARGET gpt-neox-quantize)
+add_executable(${TEST_TARGET} quantize.cpp)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
--- /dev/null
+# GPT-NeoX
+
+Transformer architecture: GPT-NeoX
+
+Ref: https://github.com/stability-AI/stableLM/#stablelm-alpha
+
+## Usage
+
+```bash
+# get the repo and build it
+git clone https://github.com/ggerganov/ggml
+cd ggml
+mkdir build && cd build
+cmake ..
+make -j
+
+# get the StableLM 3B Alpha model
+git clone https://huggingface.co/stabilityai/stablelm-base-alpha-3b
+
+# convert model to FP16
+python3 ../examples/gpt-neox/convert-h5-to-ggml.py ./stablelm-base-alpha-3b/ 1
+
+# run inference using FP16 precision
+make -j && ./bin/gpt-neox -m ./stablelm-base-alpha-3b/ggml-model-f16.bin -p "I believe the meaning of life is" -t 8 -n 64
+
+main: seed = 1681940611
+gpt_neox_model_load: loading model from 'models/stablelm-base-alpha-3b/ggml-model-f16.bin' - please wait ...
+gpt_neox_model_load: n_vocab = 50688
+gpt_neox_model_load: n_ctx = 4096
+gpt_neox_model_load: n_embd = 4096
+gpt_neox_model_load: n_head = 32
+gpt_neox_model_load: n_layer = 16
+gpt_neox_model_load: n_rot = 32
+gpt_neox_model_load: ftype = 1
+gpt_neox_model_load: ggml ctx size = 10011.10 MB
+gpt_neox_model_load: memory_size = 2048.00 MB, n_mem = 65536
+gpt_neox_model_load: ................................ done
+gpt_neox_model_load: model size = 6939.28 MB / num tensors = 260
+main: number of tokens in prompt = 7
+main: token[0] = 42, I
+main: token[1] = 2868, believe
+main: token[2] = 253, the
+main: token[3] = 4495, meaning
+main: token[4] = 273, of
+main: token[5] = 1495, life
+main: token[6] = 310, is
+
+I believe the meaning of life is to grow, to find a way, to love, to find an appreciation for life, and to live it with all of its beauty.
+
+For I am the child of God. I am the offspring of God's love. I am the offspring of the light of the world. I am the offspring of the
+
+main: mem per token = 12186760 bytes
+main: load time = 2118.55 ms
+main: sample time = 9.59 ms
+main: predict time = 4474.07 ms / 63.92 ms per token
+main: total time = 6911.26 ms
+```
+
+## 4-bit integer quantization mode
+
+```bash
+# quantize the model to 4-bits using Q4_3 quantization
+./bin/gpt-neox-quantize ./stablelm-base-alpha-3b/ggml-model-f16.bin ./stablelm-base-alpha-3b/ggml-model-q4_3.bin 6
+
+# run the quantized model
+./bin/gpt-neox -m ./stablelm-base-alpha-3b/ggml-model-q4_3.bin -p "I believe the meaning of life is" -t 8 -n 64
+
+main: seed = 1682021489
+gpt_neox_model_load: loading model from 'models/stablelm-base-alpha-3b/ggml-model-q4_3.bin' - please wait ...
+gpt_neox_model_load: n_vocab = 50688
+gpt_neox_model_load: n_ctx = 4096
+gpt_neox_model_load: n_embd = 4096
+gpt_neox_model_load: n_head = 32
+gpt_neox_model_load: n_layer = 16
+gpt_neox_model_load: n_rot = 32
+gpt_neox_model_load: ftype = 6
+gpt_neox_model_load: ggml ctx size = 5676.10 MB
+gpt_neox_model_load: memory_size = 1024.00 MB, n_mem = 65536
+gpt_neox_model_load: ........................ done
+gpt_neox_model_load: model size = 2604.28 MB / num tensors = 196
+main: number of tokens in prompt = 7
+main: token[0] = 42, I
+main: token[1] = 2868, believe
+main: token[2] = 253, the
+main: token[3] = 4495, meaning
+main: token[4] = 273, of
+main: token[5] = 1495, life
+main: token[6] = 310, is
+
+I believe the meaning of life is to love and be loved. The last three verses were enough to tie us all together. If you love someone you love them all. There are some things in this world that are just not equal in Heaven. - Be here in this moment.
+
+This world is not what is outside of us. It is what
+
+main: mem per token = 12958024 bytes
+main: load time = 850.51 ms
+main: sample time = 9.95 ms
+main: predict time = 3103.81 ms / 44.34 ms per token
+main: total time = 4177.68 ms
+
+```
+
+## Notes
+
+- No guarantees for correctness
+- The tokenizer is currently hacked - probably works only for English
+- Both parallel and non-parallel residual are supported (selected by the `use_parallel_residual` flag read from `config.json`)
+- Contributions and improvements are welcome
+
+## Note about possible bug
+
+**There might be some issue with this implementation - not 100% sure.
+The magnitude of the embeddings increases after each layer, which is unexpected.
+To observe this, uncomment the following line:**
+
+https://github.com/ggerganov/ggml/blob/abea4b7609c14b837015ab625e3ac36c4708dd03/src/ggml.c#L9208
+
+```
+...
+p[ 0] = 65.5842
+p[ 1] = 61.6951
+p[ 2] = 59.3500
+p[ 3] = 61.2421
+p[ 4] = 65.9653
+p[ 5] = 59.4936
+p[ 6] = 58.4164
+p[ 0] = -209.6351
+p[ 1] = -214.0987
+p[ 2] = -217.0928
+p[ 3] = -215.0267
+p[ 4] = -208.2430
+p[ 5] = -215.3692
+p[ 6] = -214.1981
+p[ 0] = -301.0286
+p[ 1] = -308.6521
+p[ 2] = -310.7513
+p[ 3] = -307.0832
+p[ 4] = -299.9238
+p[ 5] = -306.0667
+p[ 6] = -302.1777
+...
+```
+
+**Instead, I think the magnitude should remain around `1`.
+See https://github.com/ggerganov/llama.cpp/issues/1063#issuecomment-1527730562 for more analysis**
--- /dev/null
+import sys
+import struct
+import json
+import numpy as np
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+if len(sys.argv) < 3:
+ print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
+ print(" ftype == 0 -> float32")
+ print(" ftype == 1 -> float16")
+ sys.exit(1)
+
+# output in the same directory as the model
+dir_model = sys.argv[1]
+fname_out = sys.argv[1] + "/ggml-model.bin"
+
+with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
+ encoder = json.load(f)
+
+with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+ hparams = json.load(f)
+
+# possible data types
+# ftype == 0 -> float32
+# ftype == 1 -> float16
+#
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+ftype = 1
+if len(sys.argv) > 2:
+ ftype = int(sys.argv[2])
+ if ftype < 0 or ftype > 1:
+ print("Invalid ftype: " + str(ftype))
+ sys.exit(1)
+ fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
+
+
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
+model = AutoModelForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)
+#print (model)
+
+#print(tokenizer.encode('I believe the meaning of life is'))
+
+list_vars = model.state_dict()
+for name in list_vars.keys():
+ print(name, list_vars[name].shape, list_vars[name].dtype)
+
+fout = open(fname_out, "wb")
+
+print(hparams)
+
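+# write the binary header: the ggml magic, then the hparams (ending with the ftype)
+# in the exact order that gpt_neox_model_load() in main.cpp reads them back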
+fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+fout.write(struct.pack("i", hparams["vocab_size"]))
+fout.write(struct.pack("i", hparams["max_position_embeddings"]))
+fout.write(struct.pack("i", hparams["hidden_size"]))
+fout.write(struct.pack("i", hparams["num_attention_heads"]))
+fout.write(struct.pack("i", hparams["num_hidden_layers"]))
+fout.write(struct.pack("i", int(hparams["rotary_pct"]*(hparams["hidden_size"]//hparams["num_attention_heads"]))))
+fout.write(struct.pack("i", hparams["use_parallel_residual"]))
+fout.write(struct.pack("i", ftype))
+
+# TODO: temporary hack to not deal with implementing the tokenizer
+dot_token = tokenizer.encode('.')[0]
+for i in range(hparams["vocab_size"]):
+ text = tokenizer.decode([dot_token, i]).encode('utf-8')
+ # remove the first byte (it's always '.')
+ text = text[1:]
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+
+for name in list_vars.keys():
+ data = list_vars[name].squeeze().numpy()
+ print("Processing variable: " + name + " with shape: ", data.shape)
+
+ # we don't need these
+ if name.endswith(".attention.masked_bias") or \
+ name.endswith(".attention.bias") or \
+ name.endswith(".attention.rotary_emb.inv_freq"):
+ print(" Skipping variable: " + name)
+ continue
+
+    n_dims = len(data.shape)
+
+    # ftype == 0 -> float32, ftype == 1 -> float16
+    ftype_cur = 0
+ if ftype != 0:
+ if name[-7:] == ".weight" and n_dims == 2:
+ print(" Converting to float16")
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ print(" Converting to float32")
+ data = data.astype(np.float32)
+ ftype_cur = 0
+ else:
+ if data.dtype != np.float32:
+ print(" Converting to float32")
+ data = data.astype(np.float32)
+ ftype_cur = 0
+
+ # header
+ str = name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+    fout.write(str)
+
+ # data
+ data.tofile(fout)
+
+fout.close()
+
+print("Done. Output file: " + fname_out)
+print("")
--- /dev/null
+#include "ggml/ggml.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <iostream>
+#include <unistd.h>
+
+// default hparams (StableLM 3B)
+struct gpt_neox_hparams {
+ int32_t n_vocab = 50257;
+ int32_t n_ctx = 4096;
+ int32_t n_embd = 4096;
+ int32_t n_head = 32;
+ int32_t n_layer = 16;
+ int32_t n_rot = 32; // rotary_pct * (n_embd / n_head)
+ int32_t par_res = 1; // 1 = true, 0 = false
+ int32_t ftype = 1;
+};
+
+struct gpt_neox_layer {
+ // pre normalization
+ struct ggml_tensor * ln_1_g;
+ struct ggml_tensor * ln_1_b;
+
+ // attention
+ struct ggml_tensor * c_attn_attn_w;
+ struct ggml_tensor * c_attn_attn_b;
+
+ struct ggml_tensor * c_attn_proj_w;
+ struct ggml_tensor * c_attn_proj_b;
+
+ // post normalization
+ struct ggml_tensor * ln_2_g;
+ struct ggml_tensor * ln_2_b;
+
+ // ff
+ struct ggml_tensor * c_mlp_fc_w;
+ struct ggml_tensor * c_mlp_fc_b;
+
+ struct ggml_tensor * c_mlp_proj_w;
+ struct ggml_tensor * c_mlp_proj_b;
+};
+
+struct gpt_neox_model {
+ gpt_neox_hparams hparams;
+
+ // normalization
+ struct ggml_tensor * ln_f_g;
+ struct ggml_tensor * ln_f_b;
+
+    struct ggml_tensor * wte; // token embedding
+
+ struct ggml_tensor * lmh_g; // language model head
+ //struct ggml_tensor * lmh_b; // language model bias
+
+ std::vector<gpt_neox_layer> layers;
+
+ // key + value memory
+ struct ggml_tensor * memory_k;
+ struct ggml_tensor * memory_v;
+
+ //
+ struct ggml_context * ctx;
+ std::map<std::string, struct ggml_tensor *> tensors;
+};
+
+// load the model's weights from a file
+bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt_vocab & vocab) {
+ printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
+
+ auto fin = std::ifstream(fname, std::ios::binary);
+ if (!fin) {
+ fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
+ return false;
+ }
+
+ // verify magic
+ {
+ uint32_t magic;
+ fin.read((char *) &magic, sizeof(magic));
+ if (magic != 0x67676d6c) {
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
+ return false;
+ }
+ }
+
+ // load hparams
+ {
+ auto & hparams = model.hparams;
+
+ fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+ fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
+ fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
+ fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
+ fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+ fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
+ fin.read((char *) &hparams.par_res, sizeof(hparams.par_res));
+ fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
+
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+ printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
+ printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
+ printf("%s: n_head = %d\n", __func__, hparams.n_head);
+ printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+ printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
+ printf("%s: par_res = %d\n", __func__, hparams.par_res);
+ printf("%s: ftype = %d\n", __func__, hparams.ftype);
+ }
+
+ // load vocab
+ {
+ const int32_t n_vocab = model.hparams.n_vocab;
+
+ std::string word;
+ for (int i = 0; i < n_vocab; i++) {
+ uint32_t len;
+ fin.read((char *) &len, sizeof(len));
+
+ word.resize(len);
+ fin.read((char *) word.data(), len);
+
+ vocab.token_to_id[word] = i;
+ vocab.id_to_token[i] = word;
+ }
+ }
+
+ // for the big tensors, we have the option to store the data in 16-bit floats or quantized
+ // in order to save memory and also to speed up the computation
+ ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
+ if (wtype == GGML_TYPE_COUNT) {
+ fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
+ __func__, fname.c_str(), model.hparams.ftype);
+ return false;
+ }
+
+ auto & ctx = model.ctx;
+
+ size_t ctx_size = 0;
+
+ {
+ const auto & hparams = model.hparams;
+
+ const int n_embd = hparams.n_embd;
+ const int n_layer = hparams.n_layer;
+ const int n_ctx = hparams.n_ctx;
+ const int n_vocab = hparams.n_vocab;
+
+ ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
+ ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
+
+ ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte
+
+ ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // lmh_g
+ //ctx_size += n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b
+
+ ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
+ ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
+
+ ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
+ ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
+
+ ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
+        ctx_size += n_layer*(       n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
+
+ ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
+ ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
+
+ ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
+ ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
+
+ ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
+ ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
+
+ ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
+ ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
+
+ ctx_size += (6 + 16*n_layer)*256; // object overhead
+
+ printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
+ }
+
+ // create the ggml context
+ {
+ struct ggml_init_params params = {
+ .mem_size = ctx_size,
+ .mem_buffer = NULL,
+ .no_alloc = false,
+ };
+
+ model.ctx = ggml_init(params);
+ if (!model.ctx) {
+ fprintf(stderr, "%s: ggml_init() failed\n", __func__);
+ return false;
+ }
+ }
+
+ // prepare memory for the weights
+ {
+ const auto & hparams = model.hparams;
+
+ const int n_embd = hparams.n_embd;
+ const int n_layer = hparams.n_layer;
+ const int n_ctx = hparams.n_ctx;
+ const int n_vocab = hparams.n_vocab;
+
+ model.layers.resize(n_layer);
+
+ model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
+
+ model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+ model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+ model.lmh_g = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
+ //model.lmh_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
+
+ // map by name
+ model.tensors["gpt_neox.embed_in.weight"] = model.wte;
+
+ model.tensors["gpt_neox.final_layer_norm.weight"] = model.ln_f_g;
+ model.tensors["gpt_neox.final_layer_norm.bias"] = model.ln_f_b;
+
+ model.tensors["embed_out.weight"] = model.lmh_g;
+ //model.tensors["lm_head.bias"] = model.lmh_b;
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = model.layers[i];
+
+ layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+ layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+ layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
+ layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
+
+ layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+ layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+ layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+ layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+ layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
+ layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
+
+ layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
+ layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+
+ // map by name
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".input_layernorm.weight"] = layer.ln_1_g;
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".input_layernorm.bias"] = layer.ln_1_b;
+
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.query_key_value.weight"] = layer.c_attn_attn_w;
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.query_key_value.bias"] = layer.c_attn_attn_b;
+
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.dense.weight"] = layer.c_attn_proj_w;
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.dense.bias"] = layer.c_attn_proj_b;
+
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".post_attention_layernorm.weight"] = layer.ln_2_g;
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".post_attention_layernorm.bias"] = layer.ln_2_b;
+
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_h_to_4h.weight"] = layer.c_mlp_fc_w;
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_h_to_4h.bias"] = layer.c_mlp_fc_b;
+
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_4h_to_h.weight"] = layer.c_mlp_proj_w;
+ model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_4h_to_h.bias"] = layer.c_mlp_proj_b;
+ }
+ }
+
+ // key + value memory
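+    // (KV cache: two F16 tensors sized to hold up to n_ctx tokens per layer)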
+ {
+ const auto & hparams = model.hparams;
+
+ const int n_embd = hparams.n_embd;
+ const int n_layer = hparams.n_layer;
+ const int n_ctx = hparams.n_ctx;
+
+ const int64_t n_mem = n_layer*n_ctx;
+ const int64_t n_elements = n_embd*n_mem;
+
+ model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+ model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
+
+ const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
+
+ printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size/1024.0/1024.0, n_mem);
+ }
+
+ // load weights
+ {
+ int n_tensors = 0;
+ size_t total_size = 0;
+
+ printf("%s: ", __func__);
+
+ while (true) {
+ int32_t n_dims;
+ int32_t length;
+ int32_t ttype;
+
+ fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+ fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+ fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
+
+ if (fin.eof()) {
+ break;
+ }
+
+ int32_t nelements = 1;
+ int32_t ne[2] = { 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+ nelements *= ne[i];
+ }
+
+ std::string name(length, 0);
+ fin.read(&name[0], length);
+
+ if (model.tensors.find(name.data()) == model.tensors.end()) {
+ fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+ return false;
+ }
+
+ auto tensor = model.tensors[name.data()];
+ if (ggml_nelements(tensor) != nelements) {
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+ return false;
+ }
+
+ if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+ fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%5d, %5d], expected [%5d, %5d]\n",
+ __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
+ return false;
+ }
+
+ // for debugging
+ if (0) {
+ printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
+ }
+
+ const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+ if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+ fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+ __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+ return false;
+ }
+
+ fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+
+ total_size += ggml_nbytes(tensor);
+ if (++n_tensors % 8 == 0) {
+ printf(".");
+ fflush(stdout);
+ }
+ }
+
+ printf(" done\n");
+
+ printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
+ }
+
+ fin.close();
+
+ return true;
+}
+
+
+// feed-forward network
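+// computes: c_mlp_proj(gelu(c_mlp_fc(layernorm(inp)))), with the layer norm using ln_2_g/ln_2_b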
+ggml_tensor * gpt_neox_ff(
+ const gpt_neox_layer &layer,
+ ggml_context * ctx0,
+ ggml_tensor * inp) {
+ ggml_tensor * cur = ggml_norm(ctx0, inp);
+
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, layer.ln_2_g, cur),
+ cur),
+ ggml_repeat(ctx0, layer.ln_2_b, cur));
+
+ cur = ggml_mul_mat(ctx0,
+ layer.c_mlp_fc_w,
+ cur);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.c_mlp_fc_b, cur),
+ cur);
+
+ // GELU activation
+ cur = ggml_gelu(ctx0, cur);
+
+ // projection
+ // cur = proj_w*cur + proj_b
+ cur = ggml_mul_mat(ctx0,
+ layer.c_mlp_proj_w,
+ cur);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, layer.c_mlp_proj_b, cur),
+ cur);
+ return cur;
+}
+
+// evaluate the transformer
+//
+// - model: the model
+// - n_threads: number of threads to use
+// - n_past: the context size so far
+// - embd_inp: the embeddings of the tokens in the context
+// - embd_w: the predicted logits for the next token
+//
+bool gpt_neox_eval(
+ const gpt_neox_model & model,
+ const int n_threads,
+ const int n_past,
+ const std::vector<gpt_vocab::id> & embd_inp,
+ std::vector<float> & embd_w,
+ size_t & mem_per_token) {
+ const int N = embd_inp.size();
+
+ const auto & hparams = model.hparams;
+
+ const int n_embd = hparams.n_embd;
+ const int n_layer = hparams.n_layer;
+ const int n_ctx = hparams.n_ctx;
+ const int n_head = hparams.n_head;
+ const int n_vocab = hparams.n_vocab;
+ const int n_rot = hparams.n_rot;
+
+ static size_t buf_size = 256u*1024*1024;
+ static void * buf = malloc(buf_size);
+
+ if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+ const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+ //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+
+ // reallocate
+ buf_size = buf_size_new;
+ buf = realloc(buf, buf_size);
+ if (buf == nullptr) {
+ fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+ return false;
+ }
+ }
+
+ struct ggml_init_params params = {
+ .mem_size = buf_size,
+ .mem_buffer = buf,
+ .no_alloc = false,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+ struct ggml_cgraph gf = {};
+ gf.n_threads = n_threads;
+
+ struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+
+ // wte
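+    // (look up the embedding row of each input token id)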
+ struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd);
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * cur;
+
+ // self-attention
+ {
+ {
+ cur = ggml_norm(ctx0, inpL);
+
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
+ cur),
+ ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
+ }
+
+ // compute QKV
+ {
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].c_attn_attn_w,
+ cur);
+
+ cur = ggml_add(ctx0,
+ ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
+ cur);
+ }
+
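+            // split the fused QKV result into per-head Q, K and V views (each n_embd/n_head x n_head x N)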
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 0*sizeof(float)*n_embd/n_head));
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 1*sizeof(float)*n_embd/n_head));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));
+
+ // using mode = 2 for GPT-NeoX mode
+ Qcur = ggml_rope(ctx0, Qcur, n_past, n_rot, 2);
+ Kcur = ggml_rope(ctx0, Kcur, n_past, n_rot, 2);
+
+ // store key and value to memory
+ {
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd, N));
+
+ struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
+ struct ggml_tensor * v = ggml_view_2d(ctx0, model.memory_v, N, n_embd,
+ ( n_ctx)*ggml_element_size(model.memory_v),
+ (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v));
+
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
+ ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+ }
+
+ // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+ struct ggml_tensor * Q =
+ ggml_permute(ctx0,
+ Qcur,
+ 0, 2, 1, 3);
+
+ // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+ struct ggml_tensor * K =
+ ggml_permute(ctx0,
+ ggml_reshape_3d(ctx0,
+ ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
+ n_embd/n_head, n_head, n_past + N),
+ 0, 2, 1, 3);
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+
+ // KQ_scaled = KQ / sqrt(n_embd/n_head)
+ struct ggml_tensor * KQ_scaled =
+ ggml_scale(ctx0,
+ KQ,
+ ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
+ );
+
+ // KQ_masked = mask_past(KQ_scaled)
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+
+ // KQ = soft_max(KQ_masked)
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+
+ // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, model.memory_v,
+ n_past + N, n_embd/n_head, n_head,
+ n_ctx*ggml_element_size(model.memory_v),
+ n_ctx*ggml_element_size(model.memory_v)*n_embd/n_head,
+ il*n_ctx*ggml_element_size(model.memory_v)*n_embd);
+
+ // KQV = transpose(V) * KQ_soft_max
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+
+ // KQV_merged = KQV.permute(0, 2, 1, 3)
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ // cur = KQV_merged.contiguous().view(n_embd, N)
+ cur = ggml_cpy(ctx0,
+ KQV_merged,
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+ // projection
+ {
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].c_attn_proj_w,
+ cur);
+
+ cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur), cur);
+ }
+ }
+
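+        // residual: sequential (FF applied to the attention output) when par_res == 0,
+        // otherwise the parallel GPT-NeoX formulation where attention and FF both use the layer input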
+ if (hparams.par_res == 0) {
+ struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
+
+ cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
+
+ // input for next layer
+ inpL = ggml_add(ctx0, cur, inpFF);
+ } else {
+ struct ggml_tensor * inpFF = cur;
+
+ // this is independent of the self-attention result, so it could be done in parallel to the self-attention
+ // note here we pass inpL instead of cur
+ cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
+
+ // layer input + FF
+ cur = ggml_add(ctx0, cur, inpFF);
+
+ // input for next layer
+ inpL = ggml_add(ctx0, cur, inpL);
+ }
+ }
+
+ // norm
+ {
+ inpL = ggml_norm(ctx0, inpL);
+
+ // inpL = ln_f_g*inpL + ln_f_b
+ inpL = ggml_add(ctx0,
+ ggml_mul(ctx0,
+ ggml_repeat(ctx0, model.ln_f_g, inpL),
+ inpL),
+ ggml_repeat(ctx0, model.ln_f_b, inpL));
+ }
+
+ // lm_head
+ {
+ inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL);
+
+ //inpL = ggml_add(ctx0,
+ // ggml_repeat(ctx0, model.lmh_b, inpL),
+ // inpL);
+ }
+
+ // logits -> probs
+ //inpL = ggml_soft_max(ctx0, inpL);
+
+ // run the computation
+ ggml_build_forward_expand(&gf, inpL);
+ ggml_graph_compute (ctx0, &gf);
+
+ //if (n_past%100 == 0) {
+ // ggml_graph_print (&gf);
+ // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+ //}
+
+ //embd_w.resize(n_vocab*N);
+ //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+ // return result for just the last token
+ embd_w.resize(n_vocab);
+ memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+
+ if (mem_per_token == 0) {
+ mem_per_token = ggml_used_mem(ctx0)/N;
+ }
+ //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
+
+ ggml_free(ctx0);
+
+ return true;
+}
+
+int main(int argc, char ** argv) {
+ const int64_t t_main_start_us = ggml_time_us();
+
+ gpt_params params;
+ params.model = "models/stablelm-base-alpha-3b/ggml-model-f16.bin";
+
+ if (gpt_params_parse(argc, argv, params) == false) {
+ return 1;
+ }
+
+ if (params.seed < 0) {
+ params.seed = time(NULL);
+ }
+
+ printf("%s: seed = %d\n", __func__, params.seed);
+
+ std::mt19937 rng(params.seed);
+ if (params.prompt.empty()) {
+        if (!isatty(STDIN_FILENO)) {
+            std::string line;
+            while (std::getline(std::cin, line)) {
+ params.prompt = params.prompt + "\n" + line;
+ }
+ } else {
+ params.prompt = gpt_random_prompt(rng);
+ }
+ }
+
+ int64_t t_load_us = 0;
+
+ gpt_vocab vocab;
+ gpt_neox_model model;
+
+ // load the model
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!gpt_neox_model_load(params.model, model, vocab)) {
+ fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+ return 1;
+ }
+
+ t_load_us = ggml_time_us() - t_start_us;
+ }
+
+ int n_past = 0;
+
+ int64_t t_sample_us = 0;
+ int64_t t_predict_us = 0;
+
+ std::vector<float> logits;
+
+ // tokenize the prompt
+ std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+ params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+
+ printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+ for (int i = 0; i < embd_inp.size(); i++) {
+ printf("%s: token[%d] = %6d, %s\n", __func__, i, embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+ }
+ printf("\n");
+
+ std::vector<gpt_vocab::id> embd;
+
+ // determine the required inference memory per token:
+ size_t mem_per_token = 0;
+ gpt_neox_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+ for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+ // predict
+ if (embd.size() > 0) {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!gpt_neox_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+ printf("Failed to predict\n");
+ return 1;
+ }
+
+ t_predict_us += ggml_time_us() - t_start_us;
+ }
+
+ n_past += embd.size();
+ embd.clear();
+
+ if (i >= embd_inp.size()) {
+ // sample next token
+ const int top_k = params.top_k;
+ const float top_p = params.top_p;
+ const float temp = params.temp;
+
+ const int n_vocab = model.hparams.n_vocab;
+
+ gpt_vocab::id id = 0;
+
+ {
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+
+ t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
+
+ // add it to the context
+ embd.push_back(id);
+ } else {
+ // if here, it means we are still processing the input prompt
+ for (int k = i; k < embd_inp.size(); k++) {
+ embd.push_back(embd_inp[k]);
+ if (embd.size() > params.n_batch) {
+ break;
+ }
+ }
+ i += embd.size() - 1;
+ }
+
+ // display text
+ for (auto id : embd) {
+ printf("%s", vocab.id_to_token[id].c_str());
+ }
+ fflush(stdout);
+
+ // end of text token
+ if (embd.back() == 0) {
+ break;
+ }
+ }
+
+ // report timing
+ {
+ const int64_t t_main_end_us = ggml_time_us();
+
+ printf("\n\n");
+ printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+ printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+ printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+ printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+ printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+ }
+
+ ggml_free(model.ctx);
+
+ return 0;
+}
--- /dev/null
+#include "ggml/ggml.h"
+
+#include "common.h"
+#include "common-ggml.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <string>
+#include <vector>
+#include <regex>
+
+// default hparams (StableLM 3B)
+struct gpt_neox_hparams {
+ int32_t n_vocab = 50257;
+ int32_t n_ctx = 4096;
+ int32_t n_embd = 4096;
+ int32_t n_head = 32;
+ int32_t n_layer = 16;
+ int32_t n_rot = 32; // 0.25 * (n_embd / n_head)
+ int32_t par_res = 1; // 1 = true, 0 = false
+ int32_t ftype = 1;
+};
+
+// quantize a model
+bool gpt_neox_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
+ gpt_vocab vocab;
+
+ printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
+
+ auto finp = std::ifstream(fname_inp, std::ios::binary);
+ if (!finp) {
+ fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
+ return false;
+ }
+
+ auto fout = std::ofstream(fname_out, std::ios::binary);
+ if (!fout) {
+ fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+ return false;
+ }
+
+ // verify magic
+ {
+ uint32_t magic;
+ finp.read((char *) &magic, sizeof(magic));
+ if (magic != 0x67676d6c) {
+ fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
+ return false;
+ }
+
+ fout.write((char *) &magic, sizeof(magic));
+ }
+
+ gpt_neox_hparams hparams;
+
+ // load hparams
+ {
+ finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+ finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
+ finp.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
+ finp.read((char *) &hparams.n_head, sizeof(hparams.n_head));
+ finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+ finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
+ finp.read((char *) &hparams.par_res, sizeof(hparams.par_res));
+ finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
+
+ printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
+ printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
+ printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
+ printf("%s: n_head = %d\n", __func__, hparams.n_head);
+ printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+ printf("%s: par_res = %d\n", __func__, hparams.par_res);
+ printf("%s: ftype = %d\n", __func__, hparams.ftype);
+
+ fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
+ fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
+ fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd));
+ fout.write((char *) &hparams.n_head, sizeof(hparams.n_head));
+ fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
+ fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot));
+ fout.write((char *) &hparams.par_res, sizeof(hparams.par_res));
+ fout.write((char *) &ftype, sizeof(hparams.ftype));
+ }
+
+ // load vocab
+ {
+ const int32_t n_vocab = hparams.n_vocab;
+
+ std::string word;
+ for (int i = 0; i < n_vocab; i++) {
+ uint32_t len;
+ finp.read ((char *) &len, sizeof(len));
+ fout.write((char *) &len, sizeof(len));
+
+ word.resize(len);
+ finp.read ((char *) word.data(), len);
+ fout.write((char *) word.data(), len);
+
+ vocab.token_to_id[word] = i;
+ vocab.id_to_token[i] = word;
+ }
+ }
+
+ // regexes of tensor names to be quantized
+ const std::vector<std::string> to_quant = {
+ ".*weight",
+ };
+
+ if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {
+ fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
+ return false;
+ }
+
+ finp.close();
+ fout.close();
+
+ return true;
+}
+
+// usage:
+// ./gpt-neox-quantize models/stablelm2-117M/ggml-model.bin models/stablelm2-117M/ggml-model-quant.bin type
+//
+int main(int argc, char ** argv) {
+ if (argc != 4) {
+ fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
+ ggml_print_ftypes(stderr);
+ return 1;
+ }
+
+ // needed to initialize f16 tables
+ {
+ struct ggml_init_params params = { 0, NULL, false };
+ struct ggml_context * ctx = ggml_init(params);
+ ggml_free(ctx);
+ }
+
+ const std::string fname_inp = argv[1];
+ const std::string fname_out = argv[2];
+
+ const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
+
+ const int64_t t_main_start_us = ggml_time_us();
+
+ int64_t t_quantize_us = 0;
+
+ // load the model
+ {
+ const int64_t t_start_us = ggml_time_us();
+
+ if (!gpt_neox_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
+ fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
+ return 1;
+ }
+
+ t_quantize_us = ggml_time_us() - t_start_us;
+ }
+
+ // report timing
+ {
+ const int64_t t_main_end_us = ggml_time_us();
+
+ printf("\n");
+ printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
+ printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+ }
+
+ return 0;
+}
+++ /dev/null
-#
-# stablelm
-
-set(TEST_TARGET stablelm)
-add_executable(${TEST_TARGET} main.cpp)
-target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
-
-#
-# stablelm-quantize
-
-set(TEST_TARGET stablelm-quantize)
-add_executable(${TEST_TARGET} quantize.cpp)
-target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
+++ /dev/null
-# StableLM
-
-Transformer architecture: GPT-NeoX
-
-Ref: https://github.com/stability-AI/stableLM/#stablelm-alpha
-
-## Usage
-
-```bash
-# get the repo and build it
-git clone https://github.com/ggerganov/ggml
-cd ggml
-mkdir build && cd build
-cmake ..
-make -j
-
-# get the StableLM 3B Alpha model
-git clone https://huggingface.co/stabilityai/stablelm-base-alpha-3b
-
-# convert model to FP16
-python3 ../examples/stablelm/convert-h5-to-ggml.py ./stablelm-base-alpha-3b/ 1
-
-# run inference using FP16 precision
-make -j && ./bin/stablelm -m ./stablelm-base-alpha-3b/ggml-model-f16.bin -p "I believe the meaning of life is" -t 8 -n 64
-
-main: seed = 1681940611
-stablelm_model_load: loading model from 'models/stablelm-base-alpha-3b/ggml-model-f16.bin' - please wait ...
-stablelm_model_load: n_vocab = 50688
-stablelm_model_load: n_ctx = 4096
-stablelm_model_load: n_embd = 4096
-stablelm_model_load: n_head = 32
-stablelm_model_load: n_layer = 16
-stablelm_model_load: n_rot = 32
-stablelm_model_load: ftype = 1
-stablelm_model_load: ggml ctx size = 10011.10 MB
-stablelm_model_load: memory_size = 2048.00 MB, n_mem = 65536
-stablelm_model_load: ................................ done
-stablelm_model_load: model size = 6939.28 MB / num tensors = 260
-main: number of tokens in prompt = 7
-main: token[0] = 42, I
-main: token[1] = 2868, believe
-main: token[2] = 253, the
-main: token[3] = 4495, meaning
-main: token[4] = 273, of
-main: token[5] = 1495, life
-main: token[6] = 310, is
-
-I believe the meaning of life is to grow, to find a way, to love, to find an appreciation for life, and to live it with all of its beauty.
-
-For I am the child of God. I am the offspring of God's love. I am the offspring of the light of the world. I am the offspring of the
-
-main: mem per token = 12186760 bytes
-main: load time = 2118.55 ms
-main: sample time = 9.59 ms
-main: predict time = 4474.07 ms / 63.92 ms per token
-main: total time = 6911.26 ms
-```
-
-## 4-bit integer quantization mode
-
-```bash
-# quantize the model to 4-bits using Q4_3 quantization
-./bin/stablelm-quantize ./stablelm-base-alpha-3b/ggml-model-f16.bin ./stablelm-base-alpha-3b/ggml-model-q4_3.bin 6
-
-# run the quantized model
-./bin/stablelm -m ./stablelm-base-alpha-3b/ggml-model-q4_3.bin -p "I believe the meaning of life is" -t 8 -n 64
-
-main: seed = 1682021489
-stablelm_model_load: loading model from 'models/stablelm-base-alpha-3b/ggml-model-q4_3.bin' - please wait ...
-stablelm_model_load: n_vocab = 50688
-stablelm_model_load: n_ctx = 4096
-stablelm_model_load: n_embd = 4096
-stablelm_model_load: n_head = 32
-stablelm_model_load: n_layer = 16
-stablelm_model_load: n_rot = 32
-stablelm_model_load: ftype = 6
-stablelm_model_load: ggml ctx size = 5676.10 MB
-stablelm_model_load: memory_size = 1024.00 MB, n_mem = 65536
-stablelm_model_load: ........................ done
-stablelm_model_load: model size = 2604.28 MB / num tensors = 196
-main: number of tokens in prompt = 7
-main: token[0] = 42, I
-main: token[1] = 2868, believe
-main: token[2] = 253, the
-main: token[3] = 4495, meaning
-main: token[4] = 273, of
-main: token[5] = 1495, life
-main: token[6] = 310, is
-
-I believe the meaning of life is to love and be loved. The last three verses were enough to tie us all together. If you love someone you love them all. There are some things in this world that are just not equal in Heaven. - Be here in this moment.
-
-This world is not what is outside of us. It is what
-
-main: mem per token = 12958024 bytes
-main: load time = 850.51 ms
-main: sample time = 9.95 ms
-main: predict time = 3103.81 ms / 44.34 ms per token
-main: total time = 4177.68 ms
-
-```
-
-## Notes
-
-- No guarantees for correctness
-- The tokenizer is currently hacked - probably works only for English
-- Non-parallel residual is not supported
-- Contributions and improvements are welcome
-
-## Note about possible bug
-
-**There might be some issue with this implementation - not 100% sure.
-The embeddings magnitude increases after each layer which is unexpected.
-To observe this, uncomment the following line:**
-
-https://github.com/ggerganov/ggml/blob/abea4b7609c14b837015ab625e3ac36c4708dd03/src/ggml.c#L9208
-
-```
-...
-p[ 0] = 65.5842
-p[ 1] = 61.6951
-p[ 2] = 59.3500
-p[ 3] = 61.2421
-p[ 4] = 65.9653
-p[ 5] = 59.4936
-p[ 6] = 58.4164
-p[ 0] = -209.6351
-p[ 1] = -214.0987
-p[ 2] = -217.0928
-p[ 3] = -215.0267
-p[ 4] = -208.2430
-p[ 5] = -215.3692
-p[ 6] = -214.1981
-p[ 0] = -301.0286
-p[ 1] = -308.6521
-p[ 2] = -310.7513
-p[ 3] = -307.0832
-p[ 4] = -299.9238
-p[ 5] = -306.0667
-p[ 6] = -302.1777
-...
-```
-
-**Instead, I think the magnitude should remain around `1`.
-See https://github.com/ggerganov/llama.cpp/issues/1063#issuecomment-1527730562 for more analysis**
+++ /dev/null
-import sys
-import struct
-import json
-import torch
-import numpy as np
-
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-if len(sys.argv) < 3:
- print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
- print(" ftype == 0 -> float32")
- print(" ftype == 1 -> float16")
- sys.exit(1)
-
-# output in the same directory as the model
-dir_model = sys.argv[1]
-fname_out = sys.argv[1] + "/ggml-model.bin"
-
-with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
- encoder = json.load(f)
-
-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
- hparams = json.load(f)
-
-# possible data types
-# ftype == 0 -> float32
-# ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if len(sys.argv) > 2:
- ftype = int(sys.argv[2])
- if ftype < 0 or ftype > 1:
- print("Invalid ftype: " + str(ftype))
- sys.exit(1)
- fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-
-
-tokenizer = AutoTokenizer.from_pretrained(dir_model)
-model = AutoModelForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)
-#print (model)
-
-#print(tokenizer.encode('I believe the meaning of life is'))
-
-list_vars = model.state_dict()
-for name in list_vars.keys():
- print(name, list_vars[name].shape, list_vars[name].dtype)
-
-fout = open(fname_out, "wb")
-
-print(hparams)
-
-fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-fout.write(struct.pack("i", hparams["vocab_size"]))
-fout.write(struct.pack("i", hparams["max_position_embeddings"]))
-fout.write(struct.pack("i", hparams["hidden_size"]))
-fout.write(struct.pack("i", hparams["num_attention_heads"]))
-fout.write(struct.pack("i", hparams["num_hidden_layers"]))
-fout.write(struct.pack("i", int(hparams["rotary_pct"]*(hparams["hidden_size"]//hparams["num_attention_heads"]))))
-fout.write(struct.pack("i", ftype))
-
-# TODO: temporary hack to not deal with implementing the tokenizer
-dot_token = tokenizer.encode('.')[0]
-for i in range(hparams["vocab_size"]):
- text = tokenizer.decode([dot_token, i]).encode('utf-8')
- # remove the first byte (it's always '.')
- text = text[1:]
- fout.write(struct.pack("i", len(text)))
- fout.write(text)
-
-for name in list_vars.keys():
- data = list_vars[name].squeeze().numpy()
- print("Processing variable: " + name + " with shape: ", data.shape)
-
- # we don't need these
- if name.endswith(".attention.masked_bias") or \
- name.endswith(".attention.bias") or \
- name.endswith(".attention.rotary_emb.inv_freq"):
- print(" Skipping variable: " + name)
- continue
-
- n_dims = len(data.shape);
-
- # ftype == 0 -> float32, ftype == 1 -> float16
- ftype_cur = 0;
- if ftype != 0:
- if name[-7:] == ".weight" and n_dims == 2:
- print(" Converting to float16")
- data = data.astype(np.float16)
- ftype_cur = 1
- else:
- print(" Converting to float32")
- data = data.astype(np.float32)
- ftype_cur = 0
- else:
- if data.dtype != np.float32:
- print(" Converting to float32")
- data = data.astype(np.float32)
- ftype_cur = 0
-
- # header
- str = name.encode('utf-8')
- fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
- for i in range(n_dims):
- fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
- fout.write(str);
-
- # data
- data.tofile(fout)
-
-fout.close()
-
-print("Done. Output file: " + fname_out)
-print("")
+++ /dev/null
-#include "ggml/ggml.h"
-
-#include "common.h"
-#include "common-ggml.h"
-
-#include <cassert>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <fstream>
-#include <map>
-#include <string>
-#include <vector>
-#include <iostream>
-#include <unistd.h>
-
-// default hparams (StableLM 3B)
-struct stablelm_hparams {
- int32_t n_vocab = 50257;
- int32_t n_ctx = 4096;
- int32_t n_embd = 4096;
- int32_t n_head = 32;
- int32_t n_layer = 16;
- int32_t n_rot = 32; // rotary_pct * (n_embd / n_head)
- int32_t ftype = 1;
-};
-
-struct stablelm_layer {
- // pre normalization
- struct ggml_tensor * ln_1_g;
- struct ggml_tensor * ln_1_b;
-
- // attention
- struct ggml_tensor * c_attn_attn_w;
- struct ggml_tensor * c_attn_attn_b;
-
- struct ggml_tensor * c_attn_proj_w;
- struct ggml_tensor * c_attn_proj_b;
-
- // post normalization
- struct ggml_tensor * ln_2_g;
- struct ggml_tensor * ln_2_b;
-
- // ff
- struct ggml_tensor * c_mlp_fc_w;
- struct ggml_tensor * c_mlp_fc_b;
-
- struct ggml_tensor * c_mlp_proj_w;
- struct ggml_tensor * c_mlp_proj_b;
-};
-
-struct stablelm_model {
- stablelm_hparams hparams;
-
- // normalization
- struct ggml_tensor * ln_f_g;
- struct ggml_tensor * ln_f_b;
-
- struct ggml_tensor * wte; // position embedding
-
- struct ggml_tensor * lmh_g; // language model head
- //struct ggml_tensor * lmh_b; // language model bias
-
- std::vector<stablelm_layer> layers;
-
- // key + value memory
- struct ggml_tensor * memory_k;
- struct ggml_tensor * memory_v;
-
- //
- struct ggml_context * ctx;
- std::map<std::string, struct ggml_tensor *> tensors;
-};
-
-// load the model's weights from a file
-bool stablelm_model_load(const std::string & fname, stablelm_model & model, gpt_vocab & vocab) {
- printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
-
- auto fin = std::ifstream(fname, std::ios::binary);
- if (!fin) {
- fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
- return false;
- }
-
- // verify magic
- {
- uint32_t magic;
- fin.read((char *) &magic, sizeof(magic));
- if (magic != 0x67676d6c) {
- fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
- return false;
- }
- }
-
- // load hparams
- {
- auto & hparams = model.hparams;
-
- fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
- fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
- fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
- fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
- fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
- fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
- fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
-
- printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
- printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
- printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
- printf("%s: n_head = %d\n", __func__, hparams.n_head);
- printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
- printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
- printf("%s: ftype = %d\n", __func__, hparams.ftype);
- }
-
- // load vocab
- {
- const int32_t n_vocab = model.hparams.n_vocab;
-
- std::string word;
- for (int i = 0; i < n_vocab; i++) {
- uint32_t len;
- fin.read((char *) &len, sizeof(len));
-
- word.resize(len);
- fin.read((char *) word.data(), len);
-
- vocab.token_to_id[word] = i;
- vocab.id_to_token[i] = word;
- }
- }
-
- // for the big tensors, we have the option to store the data in 16-bit floats or quantized
- // in order to save memory and also to speed up the computation
- ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
- if (wtype == GGML_TYPE_COUNT) {
- fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
- __func__, fname.c_str(), model.hparams.ftype);
- return false;
- }
-
- auto & ctx = model.ctx;
-
- size_t ctx_size = 0;
-
- {
- const auto & hparams = model.hparams;
-
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_ctx = hparams.n_ctx;
- const int n_vocab = hparams.n_vocab;
-
- ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
- ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
-
- ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // wte
-
- ctx_size += n_embd*n_vocab*ggml_type_sizef(wtype); // lmh_g
- //ctx_size += n_vocab*ggml_type_sizef(GGML_TYPE_F32); // lmh_b
-
- ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
- ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
-
- ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
- ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
-
- ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
- ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
-
- ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
- ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
-
- ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
- ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
-
- ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
- ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
-
- ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
- ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
-
- ctx_size += (6 + 16*n_layer)*256; // object overhead
-
- printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
- }
-
- // create the ggml context
- {
- struct ggml_init_params params = {
- .mem_size = ctx_size,
- .mem_buffer = NULL,
- .no_alloc = false,
- };
-
- model.ctx = ggml_init(params);
- if (!model.ctx) {
- fprintf(stderr, "%s: ggml_init() failed\n", __func__);
- return false;
- }
- }
-
- // prepare memory for the weights
- {
- const auto & hparams = model.hparams;
-
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_ctx = hparams.n_ctx;
- const int n_vocab = hparams.n_vocab;
-
- model.layers.resize(n_layer);
-
- model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
-
- model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
- model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-
- model.lmh_g = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
- //model.lmh_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_vocab);
-
- // map by name
- model.tensors["gpt_neox.embed_in.weight"] = model.wte;
-
- model.tensors["gpt_neox.final_layer_norm.weight"] = model.ln_f_g;
- model.tensors["gpt_neox.final_layer_norm.bias"] = model.ln_f_b;
-
- model.tensors["embed_out.weight"] = model.lmh_g;
- //model.tensors["lm_head.bias"] = model.lmh_b;
-
- for (int i = 0; i < n_layer; ++i) {
- auto & layer = model.layers[i];
-
- layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
- layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-
- layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
- layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
-
- layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
- layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-
- layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
- layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-
- layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
- layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
-
- layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
- layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
-
- // map by name
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".input_layernorm.weight"] = layer.ln_1_g;
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".input_layernorm.bias"] = layer.ln_1_b;
-
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.query_key_value.weight"] = layer.c_attn_attn_w;
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.query_key_value.bias"] = layer.c_attn_attn_b;
-
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.dense.weight"] = layer.c_attn_proj_w;
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".attention.dense.bias"] = layer.c_attn_proj_b;
-
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".post_attention_layernorm.weight"] = layer.ln_2_g;
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".post_attention_layernorm.bias"] = layer.ln_2_b;
-
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_h_to_4h.weight"] = layer.c_mlp_fc_w;
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_h_to_4h.bias"] = layer.c_mlp_fc_b;
-
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_4h_to_h.weight"] = layer.c_mlp_proj_w;
- model.tensors["gpt_neox.layers." + std::to_string(i) + ".mlp.dense_4h_to_h.bias"] = layer.c_mlp_proj_b;
- }
- }
-
- // key + value memory
- {
- const auto & hparams = model.hparams;
-
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_ctx = hparams.n_ctx;
-
- const int64_t n_mem = n_layer*n_ctx;
- const int64_t n_elements = n_embd*n_mem;
-
- model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
- model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-
- const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
-
- printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size/1024.0/1024.0, n_mem);
- }
-
- // load weights
- {
- int n_tensors = 0;
- size_t total_size = 0;
-
- printf("%s: ", __func__);
-
- while (true) {
- int32_t n_dims;
- int32_t length;
- int32_t ttype;
-
- fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
- fin.read(reinterpret_cast<char *>(&length), sizeof(length));
- fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
-
- if (fin.eof()) {
- break;
- }
-
- int32_t nelements = 1;
- int32_t ne[2] = { 1, 1 };
- for (int i = 0; i < n_dims; ++i) {
- fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
- nelements *= ne[i];
- }
-
- std::string name(length, 0);
- fin.read(&name[0], length);
-
- if (model.tensors.find(name.data()) == model.tensors.end()) {
- fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
- return false;
- }
-
- auto tensor = model.tensors[name.data()];
- if (ggml_nelements(tensor) != nelements) {
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
- return false;
- }
-
- if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
- fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%5d, %5d], expected [%5d, %5d]\n",
- __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
- return false;
- }
-
- // for debugging
- if (0) {
- printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
- }
-
- const size_t bpe = ggml_type_size(ggml_type(ttype));
-
- if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
- fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
- __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
- return false;
- }
-
- fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
-
- total_size += ggml_nbytes(tensor);
- if (++n_tensors % 8 == 0) {
- printf(".");
- fflush(stdout);
- }
- }
-
- printf(" done\n");
-
- printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
- }
-
- fin.close();
-
- return true;
-}
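The loader above reads the model as a flat stream of per-tensor records: three `int32_t` header fields (`n_dims`, name length, `ttype`), the dimensions, the name, and then the raw tensor data. A minimal stand-alone reader for just the header part, mirroring that loop (the `tensor_header` type and helper name are illustrative, not part of the example):

```cpp
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// illustrative header record, matching the fields read by stablelm_model_load
struct tensor_header {
    int32_t              n_dims = 0;
    int32_t              ttype  = 0;
    std::string          name;
    std::vector<int32_t> ne;
};

// read one per-tensor header from a ggml model file; returns false at EOF
static bool read_tensor_header(std::ifstream & fin, tensor_header & hdr) {
    int32_t name_len = 0;

    fin.read(reinterpret_cast<char *>(&hdr.n_dims), sizeof(hdr.n_dims));
    fin.read(reinterpret_cast<char *>(&name_len),   sizeof(name_len));
    fin.read(reinterpret_cast<char *>(&hdr.ttype),  sizeof(hdr.ttype));
    if (fin.eof()) {
        return false;
    }

    hdr.ne.assign(hdr.n_dims, 1);
    for (int i = 0; i < hdr.n_dims; ++i) {
        fin.read(reinterpret_cast<char *>(&hdr.ne[i]), sizeof(hdr.ne[i]));
    }

    hdr.name.resize(name_len);
    fin.read(&hdr.name[0], name_len);

    return static_cast<bool>(fin); // the raw tensor data follows and must be read next
}
```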
-
-// evaluate the transformer
-//
-// - model: the model
-// - n_threads: number of threads to use
-// - n_past: the context size so far
-// - embd_inp: the embeddings of the tokens in the context
-// - embd_w: the predicted logits for the next token
-//
-bool stablelm_eval(
- const stablelm_model & model,
- const int n_threads,
- const int n_past,
- const std::vector<gpt_vocab::id> & embd_inp,
- std::vector<float> & embd_w,
- size_t & mem_per_token) {
- const int N = embd_inp.size();
-
- const auto & hparams = model.hparams;
-
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_ctx = hparams.n_ctx;
- const int n_head = hparams.n_head;
- const int n_vocab = hparams.n_vocab;
- const int n_rot = hparams.n_rot;
-
- static size_t buf_size = 256u*1024*1024;
- static void * buf = malloc(buf_size);
-
- if (mem_per_token > 0 && mem_per_token*N > buf_size) {
- const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
- //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
-
- // reallocate
- buf_size = buf_size_new;
- buf = realloc(buf, buf_size);
- if (buf == nullptr) {
- fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
- return false;
- }
- }
-
- struct ggml_init_params params = {
- .mem_size = buf_size,
- .mem_buffer = buf,
- .no_alloc = false,
- };
-
- struct ggml_context * ctx0 = ggml_init(params);
- struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
-
- struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
-
- // wte
- struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd);
-
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * cur;
-
- // self-attention
- {
- {
- cur = ggml_norm(ctx0, inpL);
-
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
- cur),
- ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
- }
-
- // compute QKV
- {
- cur = ggml_mul_mat(ctx0,
- model.layers[il].c_attn_attn_w,
- cur);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
- cur);
- }
-
- struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 0*sizeof(float)*n_embd/n_head));
- struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 1*sizeof(float)*n_embd/n_head));
- struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head));
-
- // rope: mode = 2 selects the GPT-NeoX style of rotary position embedding
- Qcur = ggml_rope(ctx0, Qcur, n_past, n_rot, 2);
- Kcur = ggml_rope(ctx0, Kcur, n_past, n_rot, 2);
-
- // store key and value to memory
- {
- Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd, N));
-
- struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
- struct ggml_tensor * v = ggml_view_2d(ctx0, model.memory_v, N, n_embd,
- ( n_ctx)*ggml_element_size(model.memory_v),
- (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v));
-
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
- }
-
- // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- Qcur,
- 0, 2, 1, 3);
-
- // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
- struct ggml_tensor * K =
- ggml_permute(ctx0,
- ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
- n_embd/n_head, n_head, n_past + N),
- 0, 2, 1, 3);
-
- // K * Q
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-
- // KQ_scaled = KQ / sqrt(n_embd/n_head)
- struct ggml_tensor * KQ_scaled =
- ggml_scale(ctx0,
- KQ,
- ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
- );
-
- // KQ_masked = mask_past(KQ_scaled)
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
-
- // KQ = soft_max(KQ_masked)
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
-
- // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
- struct ggml_tensor * V =
- ggml_view_3d(ctx0, model.memory_v,
- n_past + N, n_embd/n_head, n_head,
- n_ctx*ggml_element_size(model.memory_v),
- n_ctx*ggml_element_size(model.memory_v)*n_embd/n_head,
- il*n_ctx*ggml_element_size(model.memory_v)*n_embd);
-
- // KQV = transpose(V) * KQ_soft_max
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-
- // KQV_merged = KQV.permute(0, 2, 1, 3)
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
- // cur = KQV_merged.contiguous().view(n_embd, N)
- cur = ggml_cpy(ctx0,
- KQV_merged,
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
-
- // projection
- {
- cur = ggml_mul_mat(ctx0,
- model.layers[il].c_attn_proj_w,
- cur);
-
- cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur), cur);
- }
- }
-
- struct ggml_tensor * inpFF = cur;
-
- // feed-forward network
- // this is independent of the self-attention result, so it could be done in parallel to the self-attention
- {
- // post attention layer norm
- // note here we pass inpL instead of cur
- {
- cur = ggml_norm(ctx0, inpL);
-
- cur = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
- cur),
- ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
- }
-
- cur = ggml_mul_mat(ctx0,
- model.layers[il].c_mlp_fc_w,
- cur);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
- cur);
-
- // GELU activation
- cur = ggml_gelu(ctx0, cur);
-
- // projection
- // cur = proj_w*cur + proj_b
- cur = ggml_mul_mat(ctx0,
- model.layers[il].c_mlp_proj_w,
- cur);
-
- cur = ggml_add(ctx0,
- ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
- cur);
- }
-
- // self-attention output + FF output
- cur = ggml_add(ctx0, cur, inpFF);
-
- // input for next layer
- inpL = ggml_add(ctx0, cur, inpL);
- }
-
- // norm
- {
- inpL = ggml_norm(ctx0, inpL);
-
- // inpL = ln_f_g*inpL + ln_f_b
- inpL = ggml_add(ctx0,
- ggml_mul(ctx0,
- ggml_repeat(ctx0, model.ln_f_g, inpL),
- inpL),
- ggml_repeat(ctx0, model.ln_f_b, inpL));
- }
-
- // lm_head
- {
- inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL);
-
- //inpL = ggml_add(ctx0,
- // ggml_repeat(ctx0, model.lmh_b, inpL),
- // inpL);
- }
-
- // logits -> probs
- //inpL = ggml_soft_max(ctx0, inpL);
-
- // run the computation
- ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute (ctx0, &gf);
-
- //if (n_past%100 == 0) {
- // ggml_graph_print (&gf);
- // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
- //}
-
- //embd_w.resize(n_vocab*N);
- //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-
- // return result for just the last token
- embd_w.resize(n_vocab);
- memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
-
- if (mem_per_token == 0) {
- mem_per_token = ggml_used_mem(ctx0)/N;
- }
- //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
-
- ggml_free(ctx0);
-
- return true;
-}
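As the comments in the feed-forward block note, the second layer norm is applied to the layer input `inpL` rather than to the attention output, and the two residual additions at the end of the loop sum both branches back onto the input. In other words, each layer of the graph built above computes the GPT-NeoX parallel-residual form, roughly:

```latex
x_{l+1} = x_l + \mathrm{Attn}\bigl(\mathrm{LN}_1(x_l)\bigr) + \mathrm{MLP}\bigl(\mathrm{LN}_2(x_l)\bigr)
```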
-
-int main(int argc, char ** argv) {
- const int64_t t_main_start_us = ggml_time_us();
-
- gpt_params params;
- params.model = "models/stablelm-base-alpha-3b/ggml-model-f16.bin";
-
- if (gpt_params_parse(argc, argv, params) == false) {
- return 1;
- }
-
- if (params.seed < 0) {
- params.seed = time(NULL);
- }
-
- printf("%s: seed = %d\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
- if (params.prompt.empty()) {
- if( !isatty(STDIN_FILENO) ){
- std::string line;
- while( std::getline(std::cin, line) ){
- params.prompt = params.prompt + "\n" + line;
- }
- } else {
- params.prompt = gpt_random_prompt(rng);
- }
- }
-
- int64_t t_load_us = 0;
-
- gpt_vocab vocab;
- stablelm_model model;
-
- // load the model
- {
- const int64_t t_start_us = ggml_time_us();
-
- if (!stablelm_model_load(params.model, model, vocab)) {
- fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
- return 1;
- }
-
- t_load_us = ggml_time_us() - t_start_us;
- }
-
- int n_past = 0;
-
- int64_t t_sample_us = 0;
- int64_t t_predict_us = 0;
-
- std::vector<float> logits;
-
- // tokenize the prompt
- std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
-
- params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
-
- printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
- for (int i = 0; i < embd_inp.size(); i++) {
- printf("%s: token[%d] = %6d, %s\n", __func__, i, embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
- }
- printf("\n");
-
- std::vector<gpt_vocab::id> embd;
-
- // determine the required inference memory per token:
- size_t mem_per_token = 0;
- stablelm_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
- for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
- // predict
- if (embd.size() > 0) {
- const int64_t t_start_us = ggml_time_us();
-
- if (!stablelm_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
- printf("Failed to predict\n");
- return 1;
- }
-
- t_predict_us += ggml_time_us() - t_start_us;
- }
-
- n_past += embd.size();
- embd.clear();
-
- if (i >= embd_inp.size()) {
- // sample next token
- const int top_k = params.top_k;
- const float top_p = params.top_p;
- const float temp = params.temp;
-
- const int n_vocab = model.hparams.n_vocab;
-
- gpt_vocab::id id = 0;
-
- {
- const int64_t t_start_sample_us = ggml_time_us();
-
- id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
-
- t_sample_us += ggml_time_us() - t_start_sample_us;
- }
-
- // add it to the context
- embd.push_back(id);
- } else {
- // if here, it means we are still processing the input prompt
- for (int k = i; k < embd_inp.size(); k++) {
- embd.push_back(embd_inp[k]);
- if (embd.size() > params.n_batch) {
- break;
- }
- }
- i += embd.size() - 1;
- }
-
- // display text
- for (auto id : embd) {
- printf("%s", vocab.id_to_token[id].c_str());
- }
- fflush(stdout);
-
- // end of text token
- if (embd.back() == 0) {
- break;
- }
- }
-
- // report timing
- {
- const int64_t t_main_end_us = ggml_time_us();
-
- printf("\n\n");
- printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
- printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
- printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
- printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
- printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
- }
-
- ggml_free(model.ctx);
-
- return 0;
-}
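The dummy `stablelm_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token)` call near the top of `main` exists only to populate `mem_per_token`; later calls use that figure to grow the scratch buffer before the ggml graph is built. A simplified stand-alone sketch of that sizing pattern (the helper name is hypothetical):

```cpp
#include <cstdio>
#include <cstdlib>

// grow a scratch buffer so it can hold the ggml graph for n_tokens tokens,
// keeping ~10% headroom for ggml object overhead, as in stablelm_eval above
static bool ensure_eval_buffer(void ** buf, size_t * buf_size, size_t mem_per_token, int n_tokens) {
    if (mem_per_token > 0 && mem_per_token*n_tokens > *buf_size) {
        const size_t new_size = (size_t)(1.1*mem_per_token*n_tokens);

        void * new_buf = realloc(*buf, new_size);
        if (new_buf == nullptr) {
            fprintf(stderr, "failed to allocate %zu bytes\n", new_size);
            return false;
        }

        *buf      = new_buf;
        *buf_size = new_size;
    }
    return true;
}
```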
+++ /dev/null
-#include "ggml/ggml.h"
-
-#include "common.h"
-#include "common-ggml.h"
-
-#include <cassert>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <fstream>
-#include <map>
-#include <string>
-#include <vector>
-#include <regex>
-
-// default hparams (StableLM 3B)
-struct stablelm_hparams {
- int32_t n_vocab = 50257;
- int32_t n_ctx = 4096;
- int32_t n_embd = 4096;
- int32_t n_head = 32;
- int32_t n_layer = 16;
- int32_t n_rot = 32; // 0.25 * (n_embd / n_head)
- int32_t ftype = 1;
-};
-
-// quantize a model
-bool stablelm_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
- gpt_vocab vocab;
-
- printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
-
- auto finp = std::ifstream(fname_inp, std::ios::binary);
- if (!finp) {
- fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
- return false;
- }
-
- auto fout = std::ofstream(fname_out, std::ios::binary);
- if (!fout) {
- fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
- return false;
- }
-
- // verify magic
- {
- uint32_t magic;
- finp.read((char *) &magic, sizeof(magic));
- if (magic != 0x67676d6c) {
- fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
- return false;
- }
-
- fout.write((char *) &magic, sizeof(magic));
- }
-
- stablelm_hparams hparams;
-
- // load hparams
- {
- finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
- finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
- finp.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
- finp.read((char *) &hparams.n_head, sizeof(hparams.n_head));
- finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
- finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
- finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
-
- printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
- printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
- printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
- printf("%s: n_head = %d\n", __func__, hparams.n_head);
- printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
- printf("%s: ftype = %d\n", __func__, hparams.ftype);
-
- fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
- fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
- fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd));
- fout.write((char *) &hparams.n_head, sizeof(hparams.n_head));
- fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
- fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot));
- fout.write((char *) &ftype, sizeof(hparams.ftype));
- }
-
- // load vocab
- {
- const int32_t n_vocab = hparams.n_vocab;
-
- std::string word;
- for (int i = 0; i < n_vocab; i++) {
- uint32_t len;
- finp.read ((char *) &len, sizeof(len));
- fout.write((char *) &len, sizeof(len));
-
- word.resize(len);
- finp.read ((char *) word.data(), len);
- fout.write((char *) word.data(), len);
-
- vocab.token_to_id[word] = i;
- vocab.id_to_token[i] = word;
- }
- }
-
- // regexes of tensor names to be quantized
- const std::vector<std::string> to_quant = {
- ".*weight",
- };
-
- if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, {})) {
- fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str());
- return false;
- }
-
- finp.close();
- fout.close();
-
- return true;
-}
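The `to_quant` list above holds a single pattern, `.*weight`. Assuming the patterns are matched against the full tensor name (an assumption about how `ggml_common_quantize_0` applies them, not something shown here), this selects the weight matrices and leaves bias tensors in their original precision. A small stand-alone check with `std::regex`:

```cpp
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    // a few tensor names taken from the loader's name map
    const std::vector<std::string> names = {
        "gpt_neox.layers.0.attention.query_key_value.weight",
        "gpt_neox.layers.0.attention.query_key_value.bias",
        "gpt_neox.layers.0.mlp.dense_h_to_4h.weight",
    };

    const std::regex pattern(".*weight"); // same pattern as in to_quant

    for (const auto & name : names) {
        printf("%-55s -> %s\n", name.c_str(),
                std::regex_match(name, pattern) ? "quantize" : "keep");
    }

    return 0;
}
```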
-
-// usage:
-//  ./stablelm-quantize models/stablelm-base-alpha-3b/ggml-model-f16.bin models/stablelm-base-alpha-3b/ggml-model-quant.bin type
-//
-int main(int argc, char ** argv) {
- if (argc != 4) {
- fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
- ggml_print_ftypes(stderr);
- return 1;
- }
-
- // needed to initialize f16 tables
- {
- struct ggml_init_params params = { 0, NULL, false };
- struct ggml_context * ctx = ggml_init(params);
- ggml_free(ctx);
- }
-
- const std::string fname_inp = argv[1];
- const std::string fname_out = argv[2];
-
- const ggml_ftype ftype = ggml_parse_ftype(argv[3]);
-
- const int64_t t_main_start_us = ggml_time_us();
-
- int64_t t_quantize_us = 0;
-
- // load the model
- {
- const int64_t t_start_us = ggml_time_us();
-
- if (!stablelm_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
- fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
- return 1;
- }
-
- t_quantize_us = ggml_time_us() - t_start_us;
- }
-
- // report timing
- {
- const int64_t t_main_end_us = ggml_time_us();
-
- printf("\n");
- printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
- printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
- }
-
- return 0;
-}
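A small aside on the `verify magic` block in `stablelm_model_quantize`: the constant `0x67676d6c` is just the ASCII codes of `g`, `g`, `m`, `l` packed into one 32-bit value, which is why a mismatch is reported as "bad magic". A tiny stand-alone check:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t magic = 0x67676d6c;

    // print the four bytes from most to least significant: 'g' 'g' 'm' 'l'
    for (int shift = 24; shift >= 0; shift -= 8) {
        putchar((magic >> shift) & 0xff);
    }
    putchar('\n');

    return 0;
}
```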