]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync with sam.cpp (add SAM Vit H & L model support and fix SAM's output) ...
authorYavor Ivanov <redacted>
Mon, 28 Aug 2023 08:33:55 +0000 (11:33 +0300)
committerGitHub <redacted>
Mon, 28 Aug 2023 08:33:55 +0000 (11:33 +0300)
* Add support for Vit H and Vit L SAM model checkpoints

* Add "eps" argument to ggml_norm and fix all examples

* Fix bias addition for ConvTranspose2D layers in SAM example

* Fix build when GGML_ALLOCATOR_DEBUG is enabled

* Use op params for the stride in CONV_TRANSPOSE_2D

Needed in order for the operation to work with ggml-alloc as
the previous implementation used ggml_new_i32, which uses strach buffers

We should remove new_i32 and new_f32 I think. new_f32 is used in a lot
of places.

14 files changed:
examples/dolly-v2/main.cpp
examples/gpt-2/main.cpp
examples/gpt-j/main.cpp
examples/gpt-neox/main.cpp
examples/mpt/main.cpp
examples/replit/main.cpp
examples/sam/convert-pth-to-ggml.py
examples/sam/main.cpp
examples/starcoder/main.cpp
examples/starcoder/starcoder-mmap.cpp
examples/whisper/whisper.cpp
include/ggml/ggml.h
src/ggml-alloc.c
src/ggml.c

index aa06a6d7855af5c1c01141e35df33658ea49d267..a09cad6156352fb104c7c561c62569496da6f8cf 100644 (file)
@@ -39,6 +39,7 @@ struct dollyv2_hparams {
     int32_t n_rot   = 20;    // rotary_pct[25%] * (n_embd / n_head)
     int32_t par_res = 1; // 1 = true, 0 = false
     int32_t ftype   = GGML_FTYPE_MOSTLY_F16;
+    float   eps     = 1e-5;
 };
 
 const std::string INSTRUCTION_KEY = "### Instruction:";
@@ -412,10 +413,11 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
 
 // feed-forward network
 ggml_tensor * gpt_neox_ff(
-        const dollyv2_layer &layer,
-        ggml_context * ctx0,
-        ggml_tensor * inp) {
-    ggml_tensor * cur = ggml_norm(ctx0, inp);
+        const dollyv2_layer & layer,
+        ggml_context        * ctx0,
+        ggml_tensor         * inp,
+        float                 eps) {
+    ggml_tensor * cur = ggml_norm(ctx0, inp, eps);
 
     cur = ggml_add(ctx0,
         ggml_mul(ctx0,
@@ -509,7 +511,7 @@ bool dollyv2_eval(
         // self-attention
         {
             {
-                cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL, hparams.eps);
 
                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,
@@ -612,7 +614,7 @@ bool dollyv2_eval(
         if (hparams.par_res == 0) {
             struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
 
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF, hparams.eps);
 
             // input for next layer
             inpL = ggml_add(ctx0, cur, inpFF);
@@ -621,7 +623,7 @@ bool dollyv2_eval(
 
             // this is independent of the self-attention result, so it could be done in parallel to the self-attention
             // note here we pass inpL instead of cur
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpL, hparams.eps);
 
             // layer input + FF
             cur  = ggml_add(ctx0, cur, inpFF);
@@ -634,7 +636,7 @@ bool dollyv2_eval(
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         inpL = ggml_add(ctx0,
index 87917e3d04aa4f47849952b4b4b900ff87d09018..ed405002a720c28d659fbaa6f167a0a4725c3b93 100644 (file)
@@ -25,6 +25,7 @@ struct gpt2_hparams {
     int32_t n_head  = 12;
     int32_t n_layer = 12;
     int32_t ftype   = 1;
+    float   eps     = 1e-5;
 };
 
 struct gpt2_layer {
@@ -444,7 +445,7 @@ struct ggml_cgraph * gpt2_graph(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -590,7 +591,7 @@ struct ggml_cgraph * gpt2_graph(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -645,7 +646,7 @@ struct ggml_cgraph * gpt2_graph(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
index a0c1eea4f93979fa98f6d58e63f5e161ff6cbe8a..b23ad3d24c41aea2281b6e46122ff3082b0b9f96 100644 (file)
@@ -26,6 +26,7 @@ struct gptj_hparams {
     int32_t n_layer = 28;
     int32_t n_rot   = 64;
     int32_t ftype   = 1;
+    float   eps     = 1e-5;
 };
 
 struct gptj_layer {
@@ -437,7 +438,7 @@ bool gptj_eval(
 
         // norm
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             cur = ggml_add(ctx0,
@@ -559,7 +560,7 @@ bool gptj_eval(
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         inpL = ggml_add(ctx0,
index af6129a1d2536843d7787a973db0d032f7d8ad20..80ee6643187a0ea209202026bc0c91fdb8cda015 100644 (file)
@@ -27,6 +27,7 @@ struct gpt_neox_hparams {
     int32_t n_rot   = 32; // rotary_pct * (n_embd / n_head)
     int32_t par_res = 1; // 1 = true, 0 = false
     int32_t ftype   = 1;
+    float   eps     = 1e-5;
 };
 
 struct gpt_neox_layer {
@@ -384,10 +385,11 @@ bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt_
 
 // feed-forward network
 ggml_tensor * gpt_neox_ff(
-        const gpt_neox_layer &layer,
-        ggml_context * ctx0,
-        ggml_tensor * inp) {
-    ggml_tensor * cur = ggml_norm(ctx0, inp);
+        const gpt_neox_layer & layer,
+        ggml_context         * ctx0,
+        ggml_tensor          * inp,
+        float                  eps) {
+    ggml_tensor * cur = ggml_norm(ctx0, inp, eps);
 
     cur = ggml_add(ctx0,
         ggml_mul(ctx0,
@@ -491,7 +493,7 @@ bool gpt_neox_eval(
         // self-attention
         {
             {
-                cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL, hparams.eps);
 
                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,
@@ -596,7 +598,7 @@ bool gpt_neox_eval(
         if (hparams.par_res == 0) {
             struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
 
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF, hparams.eps);
 
             // input for next layer
             inpL = ggml_add(ctx0, cur, inpFF);
@@ -605,7 +607,7 @@ bool gpt_neox_eval(
 
             // this is independent of the self-attention result, so it could be done in parallel to the self-attention
             // note here we pass inpL instead of cur
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpL, hparams.eps);
 
             // layer input + FF
             cur  = ggml_add(ctx0, cur, inpFF);
@@ -619,7 +621,7 @@ bool gpt_neox_eval(
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         inpL = ggml_add(ctx0,
index fb8dbb6d7662bcf6b2b067f1da329da425409ef9..2fda67cc71fd2f560a409493102176e6cca12faa 100644 (file)
@@ -465,6 +465,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     const int n_head  = hparams.n_heads;
     const int n_vocab = hparams.n_vocab;
     const int n_ctx   = hparams.n_ctx;
+    const float eps   = 1e-5;
 
     static size_t buf_size = 256u * 1024 * 1024;
     static void * buf = malloc(buf_size);
@@ -513,7 +514,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
         // a = self.ln_1(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur);
         }
@@ -609,7 +610,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
         // m = self.ln_2(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur);
         }
@@ -635,7 +636,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, eps);
         // inpL = ln_f_g*inpL
         inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
     }
index 967abe61491c3d695209809657f224523f2d5974..3fb664d83edb04e6ab39a2e4a5569f9a302ca80c 100644 (file)
@@ -450,6 +450,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
     const int n_head = hparams.n_heads;
     const int n_vocab = hparams.n_vocab;
     const int n_ctx = hparams.max_seq_len;
+    const float eps = 1e-5;
 
     static size_t buf_size = 256u * 1024 * 1024;
     static void * buf = malloc(buf_size);
@@ -488,7 +489,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
 
         // a = self.ln_1(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur);
         }
@@ -577,7 +578,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
 
         // m = self.ln_2(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur);
         }
@@ -601,7 +602,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, eps);
         // inpL = ln_f_g*inpL
         inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
     }
index 91149c72523a3da2474c5927bbd7c39ecd03c6ce..5f97f0fb43147a3cccb35270637295875922645f 100644 (file)
@@ -36,17 +36,32 @@ if ftype < 0 or ftype > 1:
 
 fname_out = fname_out.replace(".bin", "-" + ftype_str[ftype] + ".bin")
 
+# Default params are set to sam_vit_b checkpoint
+n_enc_state = 768
+n_enc_layers = 12
+n_enc_heads = 12
+n_enc_out_chans = 256
+n_pt_embd = 4
+
 model = torch.load(fname_model, map_location="cpu")
+for k, v in model.items():
+    print(k, v.shape)
+    if k == "image_encoder.blocks.0.norm1.weight":
+        n_enc_state = v.shape[0]
 
-# TODO: determine based on model data
-# TODO: add decoder / prompt encoder if needed
-hparams = {
-    "n_enc_state":      768,
-    "n_enc_layers":      12,
-    "n_enc_heads":       12,
-    "n_enc_out_chans":  256,
+if n_enc_state == 1024: # sam_vit_l
+    n_enc_layers = 24
+    n_enc_heads  = 16
+elif n_enc_state == 1280: # sam_vit_h
+    n_enc_layers = 32
+    n_enc_heads  = 16
 
-    "n_pt_embd": 4,
+hparams = {
+    "n_enc_state":      n_enc_state,
+    "n_enc_layers":     n_enc_layers,
+    "n_enc_heads":      n_enc_heads,
+    "n_enc_out_chans":  n_enc_out_chans,
+    "n_pt_embd":        n_pt_embd,
 }
 
 print(hparams)
@@ -79,7 +94,7 @@ for k, v in model.items():
     #data = tf.train.load_variable(dir_model, name).squeeze()
     #data = v.numpy().squeeze()
     data = v.numpy()
-    n_dims = len(data.shape);
+    n_dims = len(data.shape)
 
     # for efficiency - transpose some matrices
     # "model/h.*/attn/c_attn/w"
@@ -113,7 +128,7 @@ for k, v in model.items():
     # keep it in F32 since the data is small
     if name == "image_encoder.patch_embed.proj.bias":
         data = data.reshape(1, data.shape[0], 1, 1)
-        n_dims = len(data.shape);
+        n_dims = len(data.shape)
         dshape = data.shape
 
     print("  New shape: ", dshape)
@@ -123,7 +138,7 @@ for k, v in model.items():
     fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
     for i in range(n_dims):
         fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-    fout.write(str);
+    fout.write(str)
 
     # data
     data.tofile(fout)
index 9f6cd838290af00db9552f69cac28e23d921032c..99b039b5624fc1f9d48bb31da97c73be05c8c4c6 100644 (file)
 #include <string>
 #include <vector>
 
+// void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
+//     printf("%s\n", title);
+//     float * data = (float *)t->data;
+//     printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
+//     printf("First & Last %d elements:\n", n);
+//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+//         printf("%.5f ", data[i]);
+//         if (i != 0 && i % t->ne[0] == 0) {
+//             printf("\n");
+//         }
+//     }
+//     printf("\n");
+//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+//         printf("%.5f ", data[ggml_nelements(t) - n + i]);
+//         if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
+//             printf("\n");
+//         }
+//     }
+//     printf("\n");
+//     double sum = 0.0;
+//     for (int i = 0; i < ggml_nelements(t); i++) {
+//         sum += data[i];
+//     }
+//     printf("sum:  %f\n\n", sum);
+// }
+
 // default hparams (ViT-B SAM)
 struct sam_hparams {
     int32_t n_enc_state               = 768;
@@ -29,6 +55,8 @@ struct sam_hparams {
     float   iou_threshold             = 0.88f;
     float   stability_score_threshold = 0.95f;
     float   stability_score_offset    = 1.0f;
+    float   eps                       = 1e-6f;
+    float   eps_decoder_transformer   = 1e-5f;
 
     int32_t n_enc_head_dim() const { return n_enc_state / n_enc_head; }
     int32_t n_img_size()     const { return 1024; }
@@ -1083,12 +1111,13 @@ struct ggml_tensor* sam_layer_norm_2d(
                     struct ggml_tensor  * layer,
                     int                   n_channels,
                     struct ggml_tensor  * w,
-                    struct ggml_tensor  * b) {
+                    struct ggml_tensor  * b,
+                    float                 eps) {
     // LayerNorm2d
     // normalize along channel dimmension
     // TODO: better implementation
     layer = ggml_permute(ctx0,
-                ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3))),
+                ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps),
                 2, 0, 1, 3);
 
     layer = ggml_add(ctx0,
@@ -1188,7 +1217,7 @@ bool sam_encode_image(
         // norm
         // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -1306,7 +1335,7 @@ bool sam_encode_image(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
@@ -1347,11 +1376,11 @@ bool sam_encode_image(
 
     cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);
 
-    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b);
+    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b, hparams.eps);
 
     cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);
 
-    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b);
+    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps);
 
     // TODO: avoid copy
     cur = ggml_cpy(ctx0, cur, state.embd_img);
@@ -1605,21 +1634,6 @@ bool sam_decode_mask(
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
 
-    // auto print_t_f32 = [&](const char* title, struct ggml_tensor * t) {
-    //     printf("%s\n", title);
-    //     float * data = (float *)t->data;
-    //     printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
-    //     printf("First 10 elements:\n");
-    //     for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
-    //         printf("%f ", data[i]);
-    //     }
-    //     printf("\n");
-    //     double sum = 0.0;
-    //     for (int i = 0; i < ggml_nelements(t); i++) {
-    //         sum += data[i];
-    //     }
-    //     printf("sum:  %f\n\n", sum);
-    // };
     // print_t_f32("embd_img", state.embd_img);
     // print_t_f32("embd_prompt_dense", state.embd_prompt_dense);
     // print_t_f32("embd_prompt_sparse", state.embd_prompt_sparse);
@@ -1713,7 +1727,7 @@ bool sam_decode_mask(
                 queries = ggml_add(ctx0, queries, self_attn);
             }
 
-            queries = ggml_norm(ctx0, queries);
+            queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
             queries = ggml_add(ctx0,
                     ggml_mul(ctx0,
                         ggml_repeat(ctx0, tfm_layer.norm1_w, queries),
@@ -1728,7 +1742,7 @@ bool sam_decode_mask(
             struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model);
 
             queries = ggml_add(ctx0, queries, cross_attn_token_to_img);
-            queries = ggml_norm(ctx0, queries);
+            queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
             queries = ggml_add(ctx0,
                     ggml_mul(ctx0,
                         ggml_repeat(ctx0, tfm_layer.norm2_w, queries),
@@ -1756,7 +1770,7 @@ bool sam_decode_mask(
                     mlp_out);
 
             queries = ggml_add(ctx0, queries, mlp_out);
-            queries = ggml_norm(ctx0, queries);
+            queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
             queries = ggml_add(ctx0,
                     ggml_mul(ctx0,
                         ggml_repeat(ctx0, tfm_layer.norm3_w, queries),
@@ -1770,7 +1784,7 @@ bool sam_decode_mask(
 
             struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model);
             keys = ggml_add(ctx0, keys, cross_attn_img_to_token);
-            keys = ggml_norm(ctx0, keys);
+            keys = ggml_norm(ctx0, keys, hparams.eps_decoder_transformer);
             keys = ggml_add(ctx0,
                     ggml_mul(ctx0,
                         ggml_repeat(ctx0, tfm_layer.norm4_w, keys),
@@ -1786,7 +1800,7 @@ bool sam_decode_mask(
         struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model);
 
         queries = ggml_add(ctx0, queries, final_attn_token_to_img);
-        queries = ggml_norm(ctx0, queries);
+        queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
         queries = ggml_add(ctx0,
                 ggml_mul(ctx0,
                     ggml_repeat(ctx0, dec.transformer_norm_final_w, queries),
@@ -1807,16 +1821,20 @@ bool sam_decode_mask(
     {
         // ConvTranspose2d
         keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
-        keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_0_b, keys), keys);
-        keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b);
+        keys = ggml_add(ctx0, keys, ggml_repeat(ctx0,
+                                     ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]),
+                                     keys));
+
+        keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps);
 
         // GELU activation
         keys = ggml_gelu(ctx0, keys);
 
         // ConvTranspose2d
         keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2);
-        keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_3_b, keys), keys);
-
+        keys = ggml_add(ctx0, ggml_repeat(ctx0,
+                                ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]),
+                                keys), keys);
         // GELU activation
         keys = ggml_gelu(ctx0, keys);
         upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]);
index 29a8d98b093518f75a91024992e66c795b982804..56576a6630cb167dd461fbb56283ba0cdedc5d0f 100644 (file)
@@ -25,6 +25,7 @@ struct starcoder_hparams {
     int32_t n_head  = 16;
     int32_t n_layer = 24;
     int32_t ftype   = 1;
+    float   eps     = 1e-5;
 };
 
 struct starcoder_layer {
@@ -487,7 +488,7 @@ bool starcoder_eval(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -636,7 +637,7 @@ bool starcoder_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -693,7 +694,7 @@ bool starcoder_eval(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
index fe8a6d9a67c8bb85beffa47f02a7e1f6be56dd84..b7d26f4765f3cc031fee78c41751438ec9672c59 100644 (file)
@@ -40,6 +40,7 @@ struct starcoder_hparams {
     int32_t n_head  = 16;
     int32_t n_layer = 24;
     int32_t ftype   = 1;
+    float   eps     = 1e-5;
 };
 
 struct starcoder_layer {
@@ -698,7 +699,7 @@ bool starcoder_eval(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -847,7 +848,7 @@ bool starcoder_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -904,7 +905,7 @@ bool starcoder_eval(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
index cad40426cf5f44829828b7cac4b94171ad979148..2d1a70b80d187c170e50a0937f4702090e43faf7 100644 (file)
@@ -440,6 +440,7 @@ struct whisper_hparams {
     int32_t n_text_layer  = 4;
     int32_t n_mels        = 80;
     int32_t ftype         = 1;
+    float   eps           = 1e-5;
 };
 
 // audio encoding layer
@@ -1555,7 +1556,7 @@ static bool whisper_encode_internal(
             {
                 wstate.use_buf(ctx0, 0);
 
-                cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL, hparams.eps);
 
                 // cur = ln_0_w*cur + ln_0_b
                 cur = ggml_add(ctx0,
@@ -1702,7 +1703,7 @@ static bool whisper_encode_internal(
                 {
                     wstate.use_buf(ctx0, 0);
 
-                    cur = ggml_norm(ctx0, inpFF);
+                    cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                     wstate.use_buf(ctx0, 1);
 
@@ -1765,7 +1766,7 @@ static bool whisper_encode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, cur);
+            cur = ggml_norm(ctx0, cur, hparams.eps);
 
             wstate.use_buf(ctx0, 1);
 
@@ -1966,7 +1967,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2093,7 +2094,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
+            cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2203,7 +2204,7 @@ static bool whisper_decode_internal(
             {
                 wstate.use_buf(ctx0, 0);
 
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 wstate.use_buf(ctx0, 1);
 
@@ -2258,7 +2259,7 @@ static bool whisper_decode_internal(
     {
         wstate.use_buf(ctx0, 0);
 
-        cur = ggml_norm(ctx0, cur);
+        cur = ggml_norm(ctx0, cur, hparams.eps);
 
         wstate.use_buf(ctx0, 1);
 
index 1baa6cea133794744969f79d8e79d5067bc4a510..f83fb5c26857be9f8faa44d0f97c8640518013e6 100644 (file)
@@ -917,14 +917,15 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // normalize along rows
-    // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
index f06f9a3c1d97b97c700631fe810f62fec280310a..6810a20f318cb58c71e2c5b81a0ef54b3817fc62 100644 (file)
@@ -271,7 +271,7 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
         /*.parse_seq     = */ {0},
         /*.has_parse_seq = */ false,
 #ifdef GGML_ALLOCATOR_DEBUG
-        /*.allocated_tensors = */ {0},
+        /*.allocated_tensors = */ {0},
 #endif
     };
 
@@ -300,7 +300,7 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
         /*.parse_seq     = */ {0},
         /*.has_parse_seq = */ false,
 #ifdef GGML_ALLOCATOR_DEBUG
-        /*.allocated_tensors = */ {0},
+        /*.allocated_tensors = */ {0},
 #endif
     };
 
@@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
                         struct ggml_tensor * view_src = get_view_source(parent);
                         struct hash_node * view_src_hn = hash_get(ht, view_src);
                         view_src_hn->n_views -= 1;
-                        AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
+                        AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
                         if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
                             ggml_allocator_free_tensor(alloc, view_src);
                         }
index 5922c330bd2eee4d81cdea5f03f236fb2dc993c5..97232219be5f59c6f00c3c69beb31caf833fd357 100644 (file)
@@ -5783,6 +5783,7 @@ struct ggml_tensor * ggml_silu_back(
 static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5793,7 +5794,7 @@ static struct ggml_tensor * ggml_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5804,14 +5805,16 @@ static struct ggml_tensor * ggml_norm_impl(
 
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, false);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, true);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, true);
 }
 
 // ggml_rms_norm
@@ -7092,11 +7095,13 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0(
     };
 
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_i32(result, 0, stride);
+
     result->op = GGML_OP_CONV_TRANSPOSE_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = ggml_new_i32(ctx, stride);
 
     return result;
 }
@@ -10613,7 +10618,8 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-5f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -13491,7 +13497,6 @@ static void ggml_compute_forward_conv_transpose_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13551,7 +13556,7 @@ static void ggml_compute_forward_conv_transpose_2d(
         return;
     }
 
-    const int32_t stride = ((const int32_t*)(opt0->data))[0];
+    const int32_t stride = ggml_get_op_params_i32(dst, 0);
 
     // total patches in dst
     const int np = ne2;
@@ -13564,7 +13569,7 @@ static void ggml_compute_forward_conv_transpose_2d(
     const int ip1 = MIN(ip0 + dp, np);
 
     ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = (ggml_fp16_t *) params->wdata + nk;
+    ggml_fp16_t * const wdata_src = wdata + nk;
 
     for (int i2 = ip0; i2 < ip1; i2++) { // Cout
         float * dst_data = (float *)((char *) dst->data + i2*nb2);
@@ -13576,9 +13581,8 @@ static void ggml_compute_forward_conv_transpose_2d(
                     for (int i00 = 0; i00 < ne00; i00++) {
                         float v = 0;
                         ggml_vec_dot_f16(ne03, &v,
-                                (ggml_fp16_t *) wdata_src + i1n,
-                                (ggml_fp16_t *) wdata_kernel + i01*ne00*ne03 + i00*ne03);
-
+                                wdata_src + i1n,
+                                wdata_kernel + i01*ne00*ne03 + i00*ne03);
                         dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
                     }
                 }
@@ -15725,7 +15729,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
-                ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_POOL_1D:
             {