]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
dolly-v2 : par_res and neox changes (#167)
authorMichael Verrilli <redacted>
Sat, 20 May 2023 14:12:24 +0000 (10:12 -0400)
committerGitHub <redacted>
Sat, 20 May 2023 14:12:24 +0000 (17:12 +0300)
* dolly-v2 example: par_res and neox changes

* Update examples/dolly-v2/quantize.cpp

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/dolly-v2/convert-h5-to-ggml.py
examples/dolly-v2/main.cpp
examples/dolly-v2/quantize.cpp

index ecbe2faddca8f962a61008d2513c3e0012be6215..0019810e28e1ff2a7ca7ad2795e4fb1e2eb41a1a 100644 (file)
@@ -1,7 +1,6 @@
 import sys
 import struct
 import json
-import torch
 import numpy as np
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -59,6 +58,7 @@ fout.write(struct.pack("i", hparams["hidden_size"]))
 fout.write(struct.pack("i", hparams["num_attention_heads"]))
 fout.write(struct.pack("i", hparams["num_hidden_layers"]))
 fout.write(struct.pack("i", int(hparams["rotary_pct"]*(hparams["hidden_size"]//hparams["num_attention_heads"]))))
+fout.write(struct.pack("i", hparams["use_parallel_residual"]))
 fout.write(struct.pack("i", ftype))
 
 # TODO: temporary hack to not deal with implementing the tokenizer
index 305093fe99c5d3ebdad876c4b71cf81dbd8dcede..235e35f27b9d52e8fcd50ce8e3fb0652c3c93c36 100644 (file)
@@ -23,6 +23,7 @@ struct dollyv2_hparams {
     int32_t n_head  = 32;    // model.config.num_attention_heads
     int32_t n_layer = 32;    // model.config.num_hidden_layers
     int32_t n_rot   = 20;    // rotary_pct[25%] * (n_embd / n_head)
+    int32_t par_res = 1; // 1 = true, 0 = false
     int32_t ftype   = GGML_FTYPE_MOSTLY_F16;
 };
 
@@ -113,6 +114,7 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
         fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
         fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
         fin.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fin.read((char *) &hparams.par_res, sizeof(hparams.par_res));        
         fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));
 
         const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
@@ -123,6 +125,7 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
         printf("%s: n_head  = %d\n", __func__, hparams.n_head);
         printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
         printf("%s: n_rot   = %d\n", __func__, hparams.n_rot);
+        printf("%s: par_res = %d\n", __func__, hparams.par_res);
         printf("%s: ftype   = %d\n", __func__, hparams.ftype);
         printf("%s: qntvr   = %d\n", __func__, qntvr);
 
@@ -390,6 +393,42 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
     return true;
 }
 
+// feed-forward network
+ggml_tensor * gpt_neox_ff(
+        const dollyv2_layer &layer,
+        ggml_context * ctx0,
+        ggml_tensor * inp) {
+    ggml_tensor * cur = ggml_norm(ctx0, inp);
+
+    cur = ggml_add(ctx0,
+        ggml_mul(ctx0,
+            ggml_repeat(ctx0, layer.ln_2_g, cur),
+            cur),
+        ggml_repeat(ctx0, layer.ln_2_b, cur));
+
+    cur = ggml_mul_mat(ctx0,
+            layer.c_mlp_fc_w,
+            cur);
+
+    cur = ggml_add(ctx0,
+            ggml_repeat(ctx0, layer.c_mlp_fc_b, cur),
+            cur);
+
+    // GELU activation
+    cur = ggml_gelu(ctx0, cur);
+
+    // projection
+    // cur = proj_w*cur + proj_b
+    cur = ggml_mul_mat(ctx0,
+            layer.c_mlp_proj_w,
+            cur);
+
+    cur = ggml_add(ctx0,
+            ggml_repeat(ctx0, layer.c_mlp_proj_b, cur),
+            cur);
+    return cur;
+}
+
 // evaluate the transformer
 //
 //   - model:     the model
@@ -554,50 +593,27 @@ bool dollyv2_eval(
             }
         }
 
-        struct ggml_tensor * inpFF = cur;
-
-        // feed-forward network
-        // this is independent of the self-attention result, so it could be done in parallel to the self-attention
-        {
-            // post attention layer norm
-            // note here we pass inpL instead of cur
-            {
-                cur = ggml_norm(ctx0, inpL);
+        if (hparams.par_res == 0) {
+            struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
 
-                cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
-                        cur),
-                    ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
-            }
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
 
-            cur = ggml_mul_mat(ctx0,
-                    model.layers[il].c_mlp_fc_w,
-                    cur);
+            // input for next layer
+            inpL = ggml_add(ctx0, cur, inpFF);
+        } else {
+            struct ggml_tensor * inpFF = cur;
 
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
-                    cur);
+            // this is independent of the self-attention result, so it could be done in parallel to the self-attention
+            // note here we pass inpL instead of cur
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
 
-            // GELU activation
-            cur = ggml_gelu(ctx0, cur);
+            // layer input + FF
+            cur  = ggml_add(ctx0, cur, inpFF);
 
-            // projection
-            // cur = proj_w*cur + proj_b
-            cur = ggml_mul_mat(ctx0,
-                    model.layers[il].c_mlp_proj_w,
-                    cur);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
-                    cur);
+            // input for next layer
+            inpL = ggml_add(ctx0, cur, inpL);
         }
-
-        // layer input + FF
-        cur  = ggml_add(ctx0, cur, inpFF);
-
-        // input for next layer
-        inpL = ggml_add(ctx0, cur, inpL);
+    
     }
 
     // norm
index 83f11e7574cf14c33e893a2dd434a4c5465d7393..83f7572749740a56bd349d01c98f0c9a3f2d0d8e 100644 (file)
 #include <vector>
 #include <regex>
 
-// default hparams (StableLM 3B)
-struct stablelm_hparams {
-    int32_t n_vocab = 50257;
-    int32_t n_ctx   = 4096;
-    int32_t n_embd  = 4096;
-    int32_t n_head  = 32;
-    int32_t n_layer = 16;
-    int32_t n_rot   = 32; // 0.25 * (n_embd / n_head)
-    int32_t ftype   = 1;
+// default hparams (dollyv2 3B)
+struct dollyv2_hparams {
+    int32_t n_vocab = 50254; // tokenizer.vocab_size
+    int32_t n_ctx   = 2048;  // model.config.max_position_embeddings
+    int32_t n_embd  = 2560;  // model.config.hidden_size
+    int32_t n_head  = 32;    // model.config.num_attention_heads
+    int32_t n_layer = 32;    // model.config.num_hidden_layers
+    int32_t n_rot   = 20;    // rotary_pct[25%] * (n_embd / n_head)
+    int32_t par_res = 1; // 1 = true, 0 = false
+    int32_t ftype   = GGML_FTYPE_MOSTLY_F16;
 };
 
 // quantize a model
-bool stablelm_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
+bool dollyv2_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
     gpt_vocab vocab;
 
     printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
@@ -54,7 +55,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         fout.write((char *) &magic, sizeof(magic));
     }
 
-    stablelm_hparams hparams;
+    dollyv2_hparams hparams;
 
     // load hparams
     {
@@ -64,6 +65,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
         finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
         finp.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        finp.read((char *) &hparams.par_res, sizeof(hparams.par_res));
         finp.read((char *) &hparams.ftype,   sizeof(hparams.ftype));
 
         const int32_t qntvr_src =    hparams.ftype / GGML_QNT_VERSION_FACTOR;
@@ -74,6 +76,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         printf("%s: n_embd      = %d\n", __func__, hparams.n_embd);
         printf("%s: n_head      = %d\n", __func__, hparams.n_head);
         printf("%s: n_layer     = %d\n", __func__, hparams.n_layer);
+        printf("%s: par_res     = %d\n", __func__, hparams.par_res);
         printf("%s: ftype (src) = %d\n", __func__, hparams.ftype);
         printf("%s: qntvr (src) = %d\n", __func__, qntvr_src);
         printf("%s: ftype (dst) = %d\n", __func__, ftype_dst);
@@ -85,6 +88,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
         fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
         fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
         fout.write((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
+        fout.write((char *) &hparams.par_res, sizeof(hparams.par_res));
         fout.write((char *) &ftype_dst,       sizeof(ftype_dst));
     }
 
@@ -124,7 +128,7 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
 }
 
 // usage:
-//  ./stablelm2-quantize models/stablelm2-117M/ggml-model.bin models/stablelm2-117M/ggml-model-quant.bin type
+//  ./dollyv2-quantize models/dolly-v2-3B/ggml-model.bin models/dolly-v2-3B/ggml-model-quant.bin type
 //
 int main(int argc, char ** argv) {
     if (argc != 4) {
@@ -153,7 +157,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!stablelm_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
+        if (!dollyv2_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
             fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
             return 1;
         }