From: Michael Verrilli
Date: Sat, 20 May 2023 14:12:24 +0000 (-0400)
Subject: dolly-v2 : par_res and neox changes (#167)
X-Git-Tag: upstream/0.0.1642~1459
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=c655bc9f8813aa9686e64cf642a17b79d5d40166;p=pkg%2Fggml%2Fsources%2Fggml

dolly-v2 : par_res and neox changes (#167)

* dolly-v2 example: par_res and neox changes

* Update examples/dolly-v2/quantize.cpp

---------

Co-authored-by: Georgi Gerganov
---

diff --git a/examples/dolly-v2/convert-h5-to-ggml.py b/examples/dolly-v2/convert-h5-to-ggml.py
index ecbe2fad..0019810e 100644
--- a/examples/dolly-v2/convert-h5-to-ggml.py
+++ b/examples/dolly-v2/convert-h5-to-ggml.py
@@ -1,7 +1,6 @@
 import sys
 import struct
 import json
-import torch
 import numpy as np

 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -59,6 +58,7 @@ fout.write(struct.pack("i", hparams["hidden_size"]))
 fout.write(struct.pack("i", hparams["num_attention_heads"]))
 fout.write(struct.pack("i", hparams["num_hidden_layers"]))
 fout.write(struct.pack("i", int(hparams["rotary_pct"]*(hparams["hidden_size"]//hparams["num_attention_heads"]))))
+fout.write(struct.pack("i", hparams["use_parallel_residual"]))
 fout.write(struct.pack("i", ftype))

 # TODO: temporary hack to not deal with implementing the tokenizer
diff --git a/examples/dolly-v2/main.cpp b/examples/dolly-v2/main.cpp
index 305093fe..235e35f2 100644
--- a/examples/dolly-v2/main.cpp
+++ b/examples/dolly-v2/main.cpp
@@ -23,6 +23,7 @@ struct dollyv2_hparams {
     int32_t n_head = 32; // model.config.num_attention_heads
     int32_t n_layer = 32; // model.config.num_hidden_layers
     int32_t n_rot = 20; // rotary_pct[25%] * (n_embd / n_head)
+    int32_t par_res = 1; // 1 = true, 0 = false
     int32_t ftype = GGML_FTYPE_MOSTLY_F16;
 };

@@ -113,6 +114,7 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
        fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
        fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
+       fin.read((char *) &hparams.par_res, sizeof(hparams.par_res));
        fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));

        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
@@ -123,6 +125,7 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
        printf("%s: n_head = %d\n", __func__, hparams.n_head);
        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
        printf("%s: n_rot = %d\n", __func__, hparams.n_rot);
+       printf("%s: par_res = %d\n", __func__, hparams.par_res);
        printf("%s: ftype = %d\n", __func__, hparams.ftype);
        printf("%s: qntvr = %d\n", __func__, qntvr);

@@ -390,6 +393,42 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
     return true;
 }

+// feed-forward network
+ggml_tensor * gpt_neox_ff(
+        const dollyv2_layer &layer,
+        ggml_context * ctx0,
+        ggml_tensor * inp) {
+    ggml_tensor * cur = ggml_norm(ctx0, inp);
+
+    cur = ggml_add(ctx0,
+        ggml_mul(ctx0,
+            ggml_repeat(ctx0, layer.ln_2_g, cur),
+            cur),
+        ggml_repeat(ctx0, layer.ln_2_b, cur));
+
+    cur = ggml_mul_mat(ctx0,
+        layer.c_mlp_fc_w,
+        cur);
+
+    cur = ggml_add(ctx0,
+        ggml_repeat(ctx0, layer.c_mlp_fc_b, cur),
+        cur);
+
+    // GELU activation
+    cur = ggml_gelu(ctx0, cur);
+
+    // projection
+    // cur = proj_w*cur + proj_b
+    cur = ggml_mul_mat(ctx0,
+        layer.c_mlp_proj_w,
+        cur);
+
+    cur = ggml_add(ctx0,
+        ggml_repeat(ctx0, layer.c_mlp_proj_b, cur),
+        cur);
+    return cur;
+}
+
 // evaluate the transformer
 //
 //   - model: the model
@@ -554,50 +593,27 @@ bool dollyv2_eval(
            }
        }

-        struct ggml_tensor * inpFF = cur;
-
-        // feed-forward network
-        // this is independent of the self-attention result, so it could be done in parallel to the self-attention
-        {
-            // post attention layer norm
-            // note here we pass inpL instead of cur
-            {
-                cur = ggml_norm(ctx0, inpL);
+        if (hparams.par_res == 0) {
+            struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);

-                cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
-                        cur),
-                    ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
-            }
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);

-            cur = ggml_mul_mat(ctx0,
-                model.layers[il].c_mlp_fc_w,
-                cur);
+            // input for next layer
+            inpL = ggml_add(ctx0, cur, inpFF);
+        } else {
+            struct ggml_tensor * inpFF = cur;

-            cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
-                cur);
+            // this is independent of the self-attention result, so it could be done in parallel to the self-attention
+            // note here we pass inpL instead of cur
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpL);

-            // GELU activation
-            cur = ggml_gelu(ctx0, cur);
+            // layer input + FF
+            cur = ggml_add(ctx0, cur, inpFF);

-            // projection
-            // cur = proj_w*cur + proj_b
-            cur = ggml_mul_mat(ctx0,
-                model.layers[il].c_mlp_proj_w,
-                cur);
-
-            cur = ggml_add(ctx0,
-                ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
-                cur);
+            // input for next layer
+            inpL = ggml_add(ctx0, cur, inpL);
         }
-
-        // layer input + FF
-        cur = ggml_add(ctx0, cur, inpFF);
-
-        // input for next layer
-        inpL = ggml_add(ctx0, cur, inpL);
+    }

     // norm
diff --git a/examples/dolly-v2/quantize.cpp b/examples/dolly-v2/quantize.cpp
index 83f11e75..83f75727 100644
--- a/examples/dolly-v2/quantize.cpp
+++ b/examples/dolly-v2/quantize.cpp
@@ -13,19 +13,20 @@
 #include <vector>
 #include <regex>

-// default hparams (StableLM 3B)
-struct stablelm_hparams {
-    int32_t n_vocab = 50257;
-    int32_t n_ctx = 4096;
-    int32_t n_embd = 4096;
-    int32_t n_head = 32;
-    int32_t n_layer = 16;
-    int32_t n_rot = 32; // 0.25 * (n_embd / n_head)
-    int32_t ftype = 1;
+// default hparams (dollyv2 3B)
+struct dollyv2_hparams {
+    int32_t n_vocab = 50254; // tokenizer.vocab_size
+    int32_t n_ctx = 2048; // model.config.max_position_embeddings
+    int32_t n_embd = 2560; // model.config.hidden_size
+    int32_t n_head = 32; // model.config.num_attention_heads
+    int32_t n_layer = 32; // model.config.num_hidden_layers
+    int32_t n_rot = 20; // rotary_pct[25%] * (n_embd / n_head)
+    int32_t par_res = 1; // 1 = true, 0 = false
+    int32_t ftype = GGML_FTYPE_MOSTLY_F16;
 };

 // quantize a model
-bool stablelm_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
+bool dollyv2_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
     gpt_vocab vocab;

     printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
@@ -54,7 +55,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         fout.write((char *) &magic, sizeof(magic));
     }

-    stablelm_hparams hparams;
+    dollyv2_hparams hparams;

     // load hparams
     {
@@ -64,6 +65,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         finp.read((char *) &hparams.n_head, sizeof(hparams.n_head));
         finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
         finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
+        finp.read((char *) &hparams.par_res, sizeof(hparams.par_res));
         finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));

         const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
@@ -74,6 +76,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
         printf("%s: n_head = %d\n", __func__, hparams.n_head);
         printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
+        printf("%s: par_res = %d\n", __func__, hparams.par_res);
         printf("%s: ftype (src) = %d\n", __func__, hparams.ftype);
         printf("%s: qntvr (src) = %d\n", __func__, qntvr_src);
         printf("%s: ftype (dst) = %d\n", __func__, ftype_dst);
@@ -85,6 +88,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         fout.write((char *) &hparams.n_head, sizeof(hparams.n_head));
         fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
         fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot));
+        fout.write((char *) &hparams.par_res, sizeof(hparams.par_res));
         fout.write((char *) &ftype_dst, sizeof(ftype_dst));
     }

@@ -124,7 +128,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
 }

 // usage:
-// ./stablelm2-quantize models/stablelm2-117M/ggml-model.bin models/stablelm2-117M/ggml-model-quant.bin type
+// ./dollyv2-quantize models/dolly-v2-3B/ggml-model.bin models/dolly-v2-3B/ggml-model-quant.bin type
 //
 int main(int argc, char ** argv) {
     if (argc != 4) {
@@ -153,7 +157,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();

-        if (!stablelm_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
+        if (!dollyv2_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
             fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
             return 1;
         }
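
Note on the change: par_res carries the use_parallel_residual flag from the
Hugging Face config into the ggml file, and the new if/else in dollyv2_eval
selects between the two GPT-NeoX residual layouts. The Python sketch below is
illustrative only; layer_parallel, layer_sequential, and the attn/ff/norm
placeholders are hypothetical names, not part of the example code. It shows
the data flow each branch of the patch computes:

    # Residual layouts selected by par_res; attn, ff, norm1, norm2 stand in
    # for the per-layer attention, feed-forward, and layer-norm blocks.

    def layer_parallel(x, attn, ff, norm1, norm2):
        # par_res == 1 (GPT-NeoX / dolly-v2 default): both branches read the
        # layer input x, so they could run in parallel
        return x + attn(norm1(x)) + ff(norm2(x))

    def layer_sequential(x, attn, ff, norm1, norm2):
        # par_res == 0: the feed-forward reads the attention output
        h = x + attn(norm1(x))
        return h + ff(norm2(h))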
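On the file format: the converter and quantizer now serialize par_res as one
extra int32 between n_rot and ftype. A minimal sketch of a header reader,
assuming the usual ggml example layout of a leading magic int32 followed by
the hparams in the order dollyv2_model_load reads them (read_dollyv2_header
is a hypothetical helper, not part of the patch):

    import struct

    def read_dollyv2_header(path):
        # field order mirrors dollyv2_model_load() in examples/dolly-v2/main.cpp
        names = ("n_vocab", "n_ctx", "n_embd", "n_head",
                 "n_layer", "n_rot", "par_res", "ftype")
        with open(path, "rb") as f:
            (magic,) = struct.unpack("i", f.read(4))  # file magic
            values = struct.unpack("8i", f.read(32))  # 8 x int32 hparams
        return magic, dict(zip(names, values))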