From: Yavor Ivanov
Date: Mon, 28 Aug 2023 08:33:55 +0000 (+0300)
Subject: ggml : sync with sam.cpp (add SAM Vit H & L model support and fix SAM's output)
X-Git-Tag: upstream/0.0.1642~1264
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=420bb1a94cf470d1acec124fbf87bca4d44ab42d;p=pkg%2Fggml%2Fsources%2Fggml

ggml : sync with sam.cpp (add SAM Vit H & L model support and fix SAM's output) (#476)

* Add support for Vit H and Vit L SAM model checkpoints

* Add "eps" argument to ggml_norm and fix all examples

* Fix bias addition for ConvTranspose2D layers in SAM example

* Fix build when GGML_ALLOCATOR_DEBUG is enabled

* Use op params for the stride in CONV_TRANSPOSE_2D

  Needed in order for the operation to work with ggml-alloc, as the previous
  implementation used ggml_new_i32, which uses scratch buffers. We should
  remove new_i32 and new_f32, I think; new_f32 is used in a lot of places.
---
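
Note: ggml_norm and ggml_norm_inplace now take the layer-norm epsilon
explicitly instead of hardcoding 1e-5f inside the op. A minimal sketch of the
updated API (the helper and the tensor names are illustrative, not taken from
any one example in this patch):

    #include "ggml/ggml.h"

    // layer_norm: hypothetical helper mirroring the pattern used in the examples
    static struct ggml_tensor * layer_norm(
            struct ggml_context * ctx0,
            struct ggml_tensor  * inp,
            struct ggml_tensor  * ln_w,   // [n_embd] scale
            struct ggml_tensor  * ln_b,   // [n_embd] shift
            float                 eps) {  // 1e-5f in most examples, 1e-6f for the SAM ViT encoder
        // old API: cur = ggml_norm(ctx0, inp); with eps fixed to 1e-5f
        struct ggml_tensor * cur = ggml_norm(ctx0, inp, eps);

        // cur = ln_w*cur + ln_b
        return ggml_add(ctx0,
                ggml_mul(ctx0, ggml_repeat(ctx0, ln_w, cur), cur),
                ggml_repeat(ctx0, ln_b, cur));
    }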
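
Note: the stride of GGML_OP_CONV_TRANSPOSE_2D now travels in the node's
op_params instead of an extra i32 tensor in src[2], so the op no longer
allocates graph data behind ggml-alloc's back. The pattern inside ggml.c, as
introduced by this patch:

    // graph construction: stash the stride in the node itself
    ggml_set_op_params_i32(result, 0, stride);

    // compute forward: read it back from the node being evaluated
    const int32_t stride = ggml_get_op_params_i32(dst, 0);

The ConvTranspose2D bias fix in the SAM decoder follows from broadcasting: the
[C] bias has to be reshaped to [1, 1, C] before ggml_repeat so it is tiled per
channel rather than per element. A sketch with an illustrative bias tensor
`up_b`:

    keys = ggml_add(ctx0, keys, ggml_repeat(ctx0,
            ggml_reshape_3d(ctx0, up_b, 1, 1, up_b->ne[0]),
            keys));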
diff --git a/examples/dolly-v2/main.cpp b/examples/dolly-v2/main.cpp
index aa06a6d7..a09cad61 100644
--- a/examples/dolly-v2/main.cpp
+++ b/examples/dolly-v2/main.cpp
@@ -39,6 +39,7 @@ struct dollyv2_hparams {
     int32_t n_rot = 20; // rotary_pct[25%] * (n_embd / n_head)
     int32_t par_res = 1; // 1 = true, 0 = false
     int32_t ftype = GGML_FTYPE_MOSTLY_F16;
+    float eps = 1e-5;
 };
 
 const std::string INSTRUCTION_KEY = "### Instruction:";
@@ -412,10 +413,11 @@ bool dollyv2_model_load(const std::string & fname, dollyv2_model & model, gpt_vo
 
 // feed-forward network
 ggml_tensor * gpt_neox_ff(
-        const dollyv2_layer &layer,
-        ggml_context * ctx0,
-        ggml_tensor * inp) {
-    ggml_tensor * cur = ggml_norm(ctx0, inp);
+        const dollyv2_layer & layer,
+        ggml_context * ctx0,
+        ggml_tensor * inp,
+        float eps) {
+    ggml_tensor * cur = ggml_norm(ctx0, inp, eps);
 
     cur = ggml_add(ctx0,
             ggml_mul(ctx0,
@@ -509,7 +511,7 @@ bool dollyv2_eval(
         // self-attention
         {
             {
-                cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL, hparams.eps);
 
                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,
@@ -612,7 +614,7 @@ bool dollyv2_eval(
         if (hparams.par_res == 0) {
             struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
 
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF, hparams.eps);
 
             // input for next layer
             inpL = ggml_add(ctx0, cur, inpFF);
@@ -621,7 +623,7 @@
 
             // this is independent of the self-attention result, so it could be done in parallel to the self-attention
             // note here we pass inpL instead of cur
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpL, hparams.eps);
 
             // layer input + FF
             cur = ggml_add(ctx0, cur, inpFF);
@@ -634,7 +636,7 @@
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         inpL = ggml_add(ctx0,
diff --git a/examples/gpt-2/main.cpp b/examples/gpt-2/main.cpp
index 87917e3d..ed405002 100644
--- a/examples/gpt-2/main.cpp
+++ b/examples/gpt-2/main.cpp
@@ -25,6 +25,7 @@ struct gpt2_hparams {
     int32_t n_head = 12;
     int32_t n_layer = 12;
     int32_t ftype = 1;
+    float eps = 1e-5;
 };
 
 struct gpt2_layer {
@@ -444,7 +445,7 @@ struct ggml_cgraph * gpt2_graph(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -590,7 +591,7 @@ struct ggml_cgraph * gpt2_graph(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -645,7 +646,7 @@ struct ggml_cgraph * gpt2_graph(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
diff --git a/examples/gpt-j/main.cpp b/examples/gpt-j/main.cpp
index a0c1eea4..b23ad3d2 100644
--- a/examples/gpt-j/main.cpp
+++ b/examples/gpt-j/main.cpp
@@ -26,6 +26,7 @@ struct gptj_hparams {
     int32_t n_layer = 28;
    int32_t n_rot = 64;
     int32_t ftype = 1;
+    float eps = 1e-5;
 };
 
 struct gptj_layer {
@@ -437,7 +438,7 @@ bool gptj_eval(
 
         // norm
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             cur = ggml_add(ctx0,
@@ -559,7 +560,7 @@ bool gptj_eval(
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         inpL = ggml_add(ctx0,
diff --git a/examples/gpt-neox/main.cpp b/examples/gpt-neox/main.cpp
index af6129a1..80ee6643 100644
--- a/examples/gpt-neox/main.cpp
+++ b/examples/gpt-neox/main.cpp
@@ -27,6 +27,7 @@ struct gpt_neox_hparams {
     int32_t n_rot = 32; // rotary_pct * (n_embd / n_head)
     int32_t par_res = 1; // 1 = true, 0 = false
     int32_t ftype = 1;
+    float eps = 1e-5;
 };
 
 struct gpt_neox_layer {
@@ -384,10 +385,11 @@ bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt_
 
 // feed-forward network
 ggml_tensor * gpt_neox_ff(
-        const gpt_neox_layer &layer,
-        ggml_context * ctx0,
-        ggml_tensor * inp) {
-    ggml_tensor * cur = ggml_norm(ctx0, inp);
+        const gpt_neox_layer & layer,
+        ggml_context * ctx0,
+        ggml_tensor * inp,
+        float eps) {
+    ggml_tensor * cur = ggml_norm(ctx0, inp, eps);
 
     cur = ggml_add(ctx0,
             ggml_mul(ctx0,
@@ -491,7 +493,7 @@ bool gpt_neox_eval(
         // self-attention
         {
             {
-                cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL, hparams.eps);
 
                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,
@@ -596,7 +598,7 @@ bool gpt_neox_eval(
         if (hparams.par_res == 0) {
             struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
 
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpFF, hparams.eps);
 
             // input for next layer
             inpL = ggml_add(ctx0, cur, inpFF);
@@ -605,7 +607,7 @@
 
             // this is independent of the self-attention result, so it could be done in parallel to the self-attention
             // note here we pass inpL instead of cur
-            cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
+            cur = gpt_neox_ff(model.layers[il], ctx0, inpL, hparams.eps);
 
             // layer input + FF
             cur = ggml_add(ctx0, cur, inpFF);
@@ -619,7 +621,7 @@
 
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         inpL = ggml_add(ctx0,
diff --git a/examples/mpt/main.cpp b/examples/mpt/main.cpp
index fb8dbb6d..2fda67cc 100644
--- a/examples/mpt/main.cpp
+++ b/examples/mpt/main.cpp
@@ -465,6 +465,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     const int n_head = hparams.n_heads;
     const int n_vocab = hparams.n_vocab;
     const int n_ctx = hparams.n_ctx;
+    const float eps = 1e-5;
 
     static size_t buf_size = 256u * 1024 * 1024;
     static void * buf = malloc(buf_size);
@@ -513,7 +514,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
         // a = self.ln_1(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur);
         }
@@ -609,7 +610,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
 
         // m = self.ln_2(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur);
         }
 
@@ -635,7 +636,7 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, eps);
 
         // inpL = ln_f_g*inpL
         inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
     }
diff --git a/examples/replit/main.cpp b/examples/replit/main.cpp
index 967abe61..3fb664d8 100644
--- a/examples/replit/main.cpp
+++ b/examples/replit/main.cpp
@@ -450,6 +450,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
     const int n_head = hparams.n_heads;
     const int n_vocab = hparams.n_vocab;
     const int n_ctx = hparams.max_seq_len;
+    const float eps = 1e-5;
 
     static size_t buf_size = 256u * 1024 * 1024;
     static void * buf = malloc(buf_size);
@@ -488,7 +489,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
 
         // a = self.ln_1(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur);
         }
@@ -577,7 +578,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
 
         // m = self.ln_2(x)
         {
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, eps);
 
             cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur);
         }
@@ -601,7 +602,7 @@ bool replit_eval(const replit_model & model, const int n_threads, const int n_pa
     // norm
     {
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, eps);
 
         // inpL = ln_f_g*inpL
         inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
     }
diff --git a/examples/sam/convert-pth-to-ggml.py b/examples/sam/convert-pth-to-ggml.py
index 91149c72..5f97f0fb 100644
--- a/examples/sam/convert-pth-to-ggml.py
+++ b/examples/sam/convert-pth-to-ggml.py
@@ -36,17 +36,32 @@ if ftype < 0 or ftype > 1:
 
 fname_out = fname_out.replace(".bin", "-" + ftype_str[ftype] + ".bin")
 
+# Default params are set to sam_vit_b checkpoint
+n_enc_state = 768
+n_enc_layers = 12
+n_enc_heads = 12
+n_enc_out_chans = 256
+n_pt_embd = 4
+
 model = torch.load(fname_model, map_location="cpu")
+for k, v in model.items():
+    print(k, v.shape)
+    if k == "image_encoder.blocks.0.norm1.weight":
+        n_enc_state = v.shape[0]
 
-# TODO: determine based on model data
-# TODO: add decoder / prompt encoder if needed
-hparams = {
-    "n_enc_state": 768,
-    "n_enc_layers": 12,
-    "n_enc_heads": 12,
-    "n_enc_out_chans": 256,
+if n_enc_state == 1024: # sam_vit_l
+    n_enc_layers = 24
+    n_enc_heads = 16
+elif n_enc_state == 1280: # sam_vit_h
+    n_enc_layers = 32
+    n_enc_heads = 16
 
-    "n_pt_embd": 4,
+hparams = {
+    "n_enc_state": n_enc_state,
+    "n_enc_layers": n_enc_layers,
+    "n_enc_heads": n_enc_heads,
+    "n_enc_out_chans": n_enc_out_chans,
+    "n_pt_embd": n_pt_embd,
 }
 
 print(hparams)
@@ -79,7 +94,7 @@ for k, v in model.items():
     #data = tf.train.load_variable(dir_model, name).squeeze()
     #data = v.numpy().squeeze()
     data = v.numpy()
-    n_dims = len(data.shape);
+    n_dims = len(data.shape)
 
     # for efficiency - transpose some matrices
     # "model/h.*/attn/c_attn/w"
@@ -113,7 +128,7 @@ for k, v in model.items():
     # keep it in F32 since the data is small
     if name == "image_encoder.patch_embed.proj.bias":
         data = data.reshape(1, data.shape[0], 1, 1)
-        n_dims = len(data.shape);
+        n_dims = len(data.shape)
 
     dshape = data.shape
     print(" New shape: ", dshape)
@@ -123,7 +138,7 @@ for k, v in model.items():
     fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
     for i in range(n_dims):
         fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-    fout.write(str);
+    fout.write(str)
 
     # data
     data.tofile(fout)
diff --git a/examples/sam/main.cpp b/examples/sam/main.cpp
index 9f6cd838..99b039b5 100644
--- a/examples/sam/main.cpp
+++ b/examples/sam/main.cpp
@@ -16,6 +16,32 @@
 #include 
 #include 
 
+// void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
+//     printf("%s\n", title);
+//     float * data = (float *)t->data;
+//     printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
+//     printf("First & Last %d elements:\n", n);
+//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+//         printf("%.5f ", data[i]);
+//         if (i != 0 && i % t->ne[0] == 0) {
+//             printf("\n");
+//         }
+//     }
+//     printf("\n");
+//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+//         printf("%.5f ", data[ggml_nelements(t) - n + i]);
+//         if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
+//             printf("\n");
+//         }
+//     }
+//     printf("\n");
+//     double sum = 0.0;
+//     for (int i = 0; i < ggml_nelements(t); i++) {
+//         sum += data[i];
+//     }
+//     printf("sum: %f\n\n", sum);
+// }
+
 // default hparams (ViT-B SAM)
 struct sam_hparams {
     int32_t n_enc_state = 768;
@@ -29,6 +55,8 @@ struct sam_hparams {
     float iou_threshold = 0.88f;
     float stability_score_threshold = 0.95f;
     float stability_score_offset = 1.0f;
+    float eps = 1e-6f;
+    float eps_decoder_transformer = 1e-5f;
 
     int32_t n_enc_head_dim() const { return n_enc_state / n_enc_head; }
     int32_t n_img_size() const { return 1024; }
@@ -1083,12 +1111,13 @@ struct ggml_tensor* sam_layer_norm_2d(
         struct ggml_tensor * layer,
         int n_channels,
         struct ggml_tensor * w,
-        struct ggml_tensor * b) {
+        struct ggml_tensor * b,
+        float eps) {
     // LayerNorm2d
     // normalize along channel dimmension
     // TODO: better implementation
     layer = ggml_permute(ctx0,
-            ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3))),
+            ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps),
             2, 0, 1, 3);
 
     layer = ggml_add(ctx0,
@@ -1188,7 +1217,7 @@ bool sam_encode_image(
     // norm
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168
     {
-        cur = ggml_norm(ctx0, inpL);
+        cur = ggml_norm(ctx0, inpL, hparams.eps);
 
         // cur = ln_0_w*cur + ln_0_b
         cur = ggml_add(ctx0,
@@ -1306,7 +1335,7 @@ bool sam_encode_image(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
                 cur = ggml_add(ctx0,
@@ -1347,11 +1376,11 @@ bool sam_encode_image(
 
         cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);
 
-        cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b);
+        cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b, hparams.eps);
 
         cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);
 
-        cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b);
+        cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps);
 
         // TODO: avoid copy
         cur = ggml_cpy(ctx0, cur, state.embd_img);
@@ -1605,21 +1634,6 @@ bool sam_decode_mask(
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
 
-    // auto print_t_f32 = [&](const char* title, struct ggml_tensor * t) {
-    //     printf("%s\n", title);
-    //     float * data = (float *)t->data;
-    //     printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
-    //     printf("First 10 elements:\n");
-    //     for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
-    //         printf("%f ", data[i]);
printf("%f ", data[i]); - // } - // printf("\n"); - // double sum = 0.0; - // for (int i = 0; i < ggml_nelements(t); i++) { - // sum += data[i]; - // } - // printf("sum: %f\n\n", sum); - // }; // print_t_f32("embd_img", state.embd_img); // print_t_f32("embd_prompt_dense", state.embd_prompt_dense); // print_t_f32("embd_prompt_sparse", state.embd_prompt_sparse); @@ -1713,7 +1727,7 @@ bool sam_decode_mask( queries = ggml_add(ctx0, queries, self_attn); } - queries = ggml_norm(ctx0, queries); + queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); queries = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, tfm_layer.norm1_w, queries), @@ -1728,7 +1742,7 @@ bool sam_decode_mask( struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model); queries = ggml_add(ctx0, queries, cross_attn_token_to_img); - queries = ggml_norm(ctx0, queries); + queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); queries = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, tfm_layer.norm2_w, queries), @@ -1756,7 +1770,7 @@ bool sam_decode_mask( mlp_out); queries = ggml_add(ctx0, queries, mlp_out); - queries = ggml_norm(ctx0, queries); + queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); queries = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, tfm_layer.norm3_w, queries), @@ -1770,7 +1784,7 @@ bool sam_decode_mask( struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model); keys = ggml_add(ctx0, keys, cross_attn_img_to_token); - keys = ggml_norm(ctx0, keys); + keys = ggml_norm(ctx0, keys, hparams.eps_decoder_transformer); keys = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, tfm_layer.norm4_w, keys), @@ -1786,7 +1800,7 @@ bool sam_decode_mask( struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model); queries = ggml_add(ctx0, queries, final_attn_token_to_img); - queries = ggml_norm(ctx0, queries); + queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); queries = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, dec.transformer_norm_final_w, queries), @@ -1807,16 +1821,20 @@ bool sam_decode_mask( { // ConvTranspose2d keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2); - keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_0_b, keys), keys); - keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b); + keys = ggml_add(ctx0, keys, ggml_repeat(ctx0, + ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]), + keys)); + + keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps); // GELU activation keys = ggml_gelu(ctx0, keys); // ConvTranspose2d keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2); - keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_3_b, keys), keys); - + keys = ggml_add(ctx0, ggml_repeat(ctx0, + ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]), + keys), keys); // GELU activation keys = ggml_gelu(ctx0, keys); upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]); diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp index 29a8d98b..56576a66 100644 --- a/examples/starcoder/main.cpp +++ 
@@ -25,6 +25,7 @@ struct starcoder_hparams {
     int32_t n_head = 16;
     int32_t n_layer = 24;
     int32_t ftype = 1;
+    float eps = 1e-5;
 };
 
 struct starcoder_layer {
@@ -487,7 +488,7 @@ bool starcoder_eval(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -636,7 +637,7 @@ bool starcoder_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -693,7 +694,7 @@ bool starcoder_eval(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
diff --git a/examples/starcoder/starcoder-mmap.cpp b/examples/starcoder/starcoder-mmap.cpp
index fe8a6d9a..b7d26f47 100644
--- a/examples/starcoder/starcoder-mmap.cpp
+++ b/examples/starcoder/starcoder-mmap.cpp
@@ -40,6 +40,7 @@ struct starcoder_hparams {
     int32_t n_head = 16;
     int32_t n_layer = 24;
     int32_t ftype = 1;
+    float eps = 1e-5;
 };
 
 struct starcoder_layer {
@@ -698,7 +699,7 @@ bool starcoder_eval(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -847,7 +848,7 @@ bool starcoder_eval(
         {
             // norm
            {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -904,7 +905,7 @@ bool starcoder_eval(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, hparams.eps);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp
index cad40426..2d1a70b8 100644
--- a/examples/whisper/whisper.cpp
+++ b/examples/whisper/whisper.cpp
@@ -440,6 +440,7 @@ struct whisper_hparams {
     int32_t n_text_layer = 4;
     int32_t n_mels = 80;
     int32_t ftype = 1;
+    float eps = 1e-5;
 };
 
 // audio encoding layer
@@ -1555,7 +1556,7 @@ static bool whisper_encode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -1702,7 +1703,7 @@ static bool whisper_encode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpFF);
+            cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
             wstate.use_buf(ctx0, 1);
 
@@ -1765,7 +1766,7 @@ static bool whisper_encode_internal(
     {
         wstate.use_buf(ctx0, 0);
 
-        cur = ggml_norm(ctx0, cur);
+        cur = ggml_norm(ctx0, cur, hparams.eps);
 
         wstate.use_buf(ctx0, 1);
 
@@ -1966,7 +1967,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2093,7 +2094,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
+            cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
 
             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2203,7 +2204,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_norm(ctx0, inpFF);
+            cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
             wstate.use_buf(ctx0, 1);
 
@@ -2258,7 +2259,7 @@ static bool whisper_decode_internal(
     {
         wstate.use_buf(ctx0, 0);
 
-        cur = ggml_norm(ctx0, cur);
+        cur = ggml_norm(ctx0, cur, hparams.eps);
 
         wstate.use_buf(ctx0, 1);
 
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index 1baa6cea..f83fb5c2 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -917,14 +917,15 @@ extern "C" {
extern "C" { struct ggml_tensor * b); // normalize along rows - // TODO: eps is hardcoded to 1e-5 for now GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); GGML_API struct ggml_tensor * ggml_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, + float eps); GGML_API struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c index f06f9a3c..6810a20f 100644 --- a/src/ggml-alloc.c +++ b/src/ggml-alloc.c @@ -271,7 +271,7 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) /*.parse_seq = */ {0}, /*.has_parse_seq = */ false, #ifdef GGML_ALLOCATOR_DEBUG - /*.allocated_tensors = */ = {0}, + /*.allocated_tensors = */ {0}, #endif }; @@ -300,7 +300,7 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { /*.parse_seq = */ {0}, /*.has_parse_seq = */ false, #ifdef GGML_ALLOCATOR_DEBUG - /*.allocated_tensors = */ = {0}, + /*.allocated_tensors = */ {0}, #endif }; @@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n( struct ggml_tensor * view_src = get_view_source(parent); struct hash_node * view_src_hn = hash_get(ht, view_src); view_src_hn->n_views -= 1; - AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views); + AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) { ggml_allocator_free_tensor(alloc, view_src); } diff --git a/src/ggml.c b/src/ggml.c index 5922c330..97232219 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -5783,6 +5783,7 @@ struct ggml_tensor * ggml_silu_back( static struct ggml_tensor * ggml_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, + float eps, bool inplace) { bool is_node = false; @@ -5793,7 +5794,7 @@ static struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - // TODO: maybe store epsilon here? + ggml_set_op_params(result, &eps, sizeof(eps)); result->op = GGML_OP_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -5804,14 +5805,16 @@ static struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * ggml_norm( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, false); + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_norm_inplace( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_norm_impl(ctx, a, true); + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, true); } // ggml_rms_norm @@ -7092,11 +7095,13 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, stride); + result->op = GGML_OP_CONV_TRANSPOSE_2D; result->grad = is_node ? 
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = ggml_new_i32(ctx, stride);
 
     return result;
 }
@@ -10613,7 +10618,8 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-5f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -13491,7 +13497,6 @@ static void ggml_compute_forward_conv_transpose_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13551,7 +13556,7 @@ static void ggml_compute_forward_conv_transpose_2d(
         return;
     }
 
-    const int32_t stride = ((const int32_t*)(opt0->data))[0];
+    const int32_t stride = ggml_get_op_params_i32(dst, 0);
 
     // total patches in dst
     const int np = ne2;
@@ -13564,7 +13569,7 @@ static void ggml_compute_forward_conv_transpose_2d(
     const int ip1 = MIN(ip0 + dp, np);
 
     ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = (ggml_fp16_t *) params->wdata + nk;
+    ggml_fp16_t * const wdata_src = wdata + nk;
 
     for (int i2 = ip0; i2 < ip1; i2++) { // Cout
         float * dst_data = (float *)((char *) dst->data + i2*nb2);
@@ -13576,9 +13581,8 @@ static void ggml_compute_forward_conv_transpose_2d(
                 for (int i00 = 0; i00 < ne00; i00++) {
                     float v = 0;
                     ggml_vec_dot_f16(ne03, &v,
-                            (ggml_fp16_t *) wdata_src + i1n,
-                            (ggml_fp16_t *) wdata_kernel + i01*ne00*ne03 + i00*ne03);
-
+                            wdata_src + i1n,
+                            wdata_kernel + i01*ne00*ne03 + i00*ne03);
                     dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
                 }
             }
@@ -15725,7 +15729,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
-                ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor);
            } break;
        case GGML_OP_POOL_1D:
            {