From: Yavor Ivanov
Date: Mon, 28 Aug 2023 12:40:23 +0000 (+0300)
Subject: sam : use ggml-alloc (#490)

sam : use ggml-alloc (#490)
---

diff --git a/examples/sam/README.md b/examples/sam/README.md
index fa4b993f..d8702f1d 100644
--- a/examples/sam/README.md
+++ b/examples/sam/README.md
@@ -8,15 +8,15 @@ The example currently supports only the [ViT-B SAM model checkpoint](https://hug
 
 ## Next steps
 
-- [ ] Reduce memory usage by utilizing the new ggml-alloc
+- [X] Reduce memory usage by utilizing the new ggml-alloc
 - [ ] Remove redundant graph nodes
 - [ ] Make inference faster
-- [ ] Fix the difference in output masks compared to the PyTorch implementation
+- [X] Fix the difference in output masks compared to the PyTorch implementation
 - [X] Filter masks based on stability score
 - [ ] Add support for user input
 - [ ] Support F16 for heavy F32 ops
 - [ ] Test quantization
-- [ ] Support bigger model checkpoints
+- [X] Support bigger model checkpoints
 - [ ] GPU support
 
 ## Quick start
diff --git a/examples/sam/main.cpp b/examples/sam/main.cpp
index f5715691..e8bfbb3d 100644
--- a/examples/sam/main.cpp
+++ b/examples/sam/main.cpp
@@ -1,9 +1,7 @@
 #define _USE_MATH_DEFINES // for M_PI
 
 #include "ggml.h"
-
-#include "common.h"
-
+#include "ggml-alloc.h"
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
 #define STB_IMAGE_WRITE_IMPLEMENTATION
@@ -18,32 +16,7 @@
 #include <map>
 #include <string>
 #include <vector>
-
-// void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
-//     printf("%s\n", title);
-//     float * data = (float *)t->data;
-//     printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
-//     printf("First & Last %d elements:\n", n);
-//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
-//         printf("%.5f ", data[i]);
-//         if (i != 0 && i % t->ne[0] == 0) {
-//             printf("\n");
-//         }
-//     }
-//     printf("\n");
-//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
-//         printf("%.5f ", data[ggml_nelements(t) - n + i]);
-//         if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
-//             printf("\n");
-//         }
-//     }
-//     printf("\n");
-//     double sum = 0.0;
-//     for (int i = 0; i < ggml_nelements(t); i++) {
-//         sum += data[i];
-//     }
-//     printf("sum: %f\n\n", sum);
-// }
+#include <thread>
 
 // default hparams (ViT-B SAM)
 struct sam_hparams {
@@ -252,20 +225,37 @@ struct sam_decoder_mask {
     struct ggml_tensor * mask_tokens_w;
 };
 
+
 struct sam_state {
     struct ggml_tensor * embd_img;
-    struct ggml_tensor * embd_prompt_sparse;
-    struct ggml_tensor * embd_prompt_dense;
-    struct ggml_tensor * pe_img_dense;
 
     struct ggml_tensor * low_res_masks;
     struct ggml_tensor * iou_predictions;
 
+    //struct ggml_tensor * tmp_save = {};
+
     struct ggml_context * ctx;
 
-    std::vector<uint8_t> buf;
+    // buffer for `ggml_graph_plan.work_data`
+    std::vector<uint8_t> work_buffer;
+
+    // buffers to evaluate the model
+    std::vector<uint8_t> buf_alloc_img_enc;
+    std::vector<uint8_t> buf_compute_img_enc;
+
+    std::vector<uint8_t> buf_alloc_fast;
+    std::vector<uint8_t> buf_compute_fast;
+
+    struct ggml_allocr * allocr = {};
 };
 
+// void save_tensor(sam_state& state, struct ggml_tensor * t, struct ggml_cgraph * gf) {
+//     if (!state.tmp_save) {
+//         state.tmp_save = ggml_new_tensor(state.ctx, t->type, t->n_dims, t->ne);
+//     }
+//     struct ggml_tensor * tmp0 = ggml_cpy(state.ctx, t, state.tmp_save);
+//     ggml_build_forward_expand(gf, tmp0);
+// }
+
 struct sam_model {
     sam_hparams hparams;
@@ -300,7 +290,51 @@ struct sam_image_f32 {
     std::vector<float> data;
 };
 
-void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
+void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
+    printf("%s\n", title);
+    float * data = (float *)t->data;
+    printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
+    printf("First & Last %d elements:\n", n);
+    for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+        printf("%.5f ", data[i]);
+        if (i != 0 && i % t->ne[0] == 0) {
+            printf("\n");
+        }
+    }
+    printf("\n");
+    for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+        printf("%.5f ", data[ggml_nelements(t) - n + i]);
+        if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
+            printf("\n");
+        }
+    }
+    printf("\n");
+    double sum = 0.0;
+    for (int i = 0; i < ggml_nelements(t); i++) {
+        sum += data[i];
+    }
+    printf("sum: %f\n\n", sum);
+}
+
+static void ggml_disconnect_node_from_graph(ggml_tensor * t) {
+    t->op = GGML_OP_NONE;
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        t->src[i] = NULL;
+    }
+}
+
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
+static void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
     GGML_ASSERT(userdata == NULL);
     GGML_ASSERT(ggml_are_same_shape(dst, src));
     GGML_ASSERT(ggml_is_contiguous(dst));
@@ -319,7 +353,7 @@ void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int
     }
 }
 
-void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
+static void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
     GGML_ASSERT(userdata == NULL);
     GGML_ASSERT(ggml_are_same_shape(dst, src));
     GGML_ASSERT(ggml_is_contiguous(dst));
@@ -1043,31 +1077,21 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
     return true;
 }
 
-bool sam_fill_dense_pe(
-            const sam_model & model,
-                  sam_state & state,
-                        int   n_threads) {
+struct ggml_tensor * sam_fill_dense_pe(
+            const sam_model     & model,
+          struct ggml_context   * ctx0,
+          struct ggml_cgraph    * gf,
+                  sam_state     & state) {
     const auto & hparams = model.hparams;
     const auto & enc     = model.enc_prompt;
 
     const int32_t n_img_embd = hparams.n_img_embd();
     const float n_img_embd_inv = 1.0f / n_img_embd;
 
-    static size_t buf_size = 256u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
-
     struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd);
+    ggml_allocr_alloc(state.allocr, xy_embed_stacked);
 
-    {
+    if (!ggml_allocr_is_measure(state.allocr)) {
         float * data = (float *) ggml_get_data(xy_embed_stacked);
         for (int i = 0; i < n_img_embd; ++i) {
             const int row = 2*i*n_img_embd;
@@ -1092,21 +1116,14 @@ bool sam_fill_dense_pe(
 
         cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);
 
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0)));
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0,
t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1]))); } - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - - // TODO: avoid copy - cur = ggml_cpy(ctx0, cur, state.pe_img_dense); + struct ggml_tensor * pe_img_dense = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); + ggml_build_forward_expand(gf, pe_img_dense); - // run the computation - ggml_set_name(cur, "check"); - ggml_build_forward_expand(&gf, cur); - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - - return true; + return pe_img_dense; } struct ggml_tensor* sam_layer_norm_2d( @@ -1132,11 +1149,10 @@ struct ggml_tensor* sam_layer_norm_2d( return layer; } -bool sam_encode_image( +struct ggml_cgraph * sam_encode_image( const sam_model & model, sam_state & state, - const sam_image_f32 & img, - int n_threads) { + const sam_image_f32 & img) { const auto & hparams = model.hparams; const auto & enc = model.enc_img; @@ -1146,32 +1162,21 @@ bool sam_encode_image( const int32_t n_enc_head = hparams.n_enc_head; const int32_t n_enc_head_dim = hparams.n_enc_head_dim(); const int32_t n_enc_out_chans = hparams.n_enc_out_chans; - const int32_t n_img_size = hparams.n_img_size(); const int32_t n_window_size = hparams.n_window_size(); - static size_t buf_size = 256u*1024*1024; - static void * buf = malloc(buf_size); - - // use 2 scratch buffers - // TODO: very hacky solution - reimplement in a more elegant way - static size_t scr0_size = 2048u*1024*1024; - static void * scr0 = malloc(scr0_size); - - static size_t scr1_size = 512u*1024*1024; - static void * scr1 = malloc(scr1_size); - - struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf, - /*.no_alloc =*/ false, + struct ggml_init_params ggml_params = { + /*.mem_size =*/ state.buf_compute_img_enc.size(), + /*.mem_buffer =*/ state.buf_compute_img_enc.data(), + /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements }; - struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; + struct ggml_context * ctx0 = ggml_init(ggml_params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3, 1); - { + ggml_allocr_alloc(state.allocr, inp); + if (!ggml_allocr_is_measure(state.allocr)) { float * data = (float *) ggml_get_data(inp); const int nx = img.nx; @@ -1215,8 +1220,6 @@ bool sam_encode_image( for (int il = 0; il < n_enc_layer; ++il) { const auto & layer = enc.layers[il]; - ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); - // norm // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168 { @@ -1242,8 +1245,6 @@ bool sam_encode_image( const int64_t W = cur->ne[1]; const int64_t H = cur->ne[2]; - ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); - // self-attention { cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); @@ -1279,8 +1280,6 @@ bool sam_encode_image( V = ggml_cont (ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // transposed V = ggml_reshape_3d(ctx0, V, W*H, n_enc_head_dim, B*n_enc_head); - ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); - struct ggml_tensor * KQ = 
ggml_mul_mat(ctx0, K, Q); struct ggml_tensor * KQ_scaled = @@ -1330,8 +1329,6 @@ bool sam_encode_image( cur = ggml_add(ctx0, inpL, cur); - ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); - struct ggml_tensor * inpFF = cur; // feed-forward network @@ -1373,8 +1370,6 @@ bool sam_encode_image( inpL = ggml_add(ctx0, cur, inpFF); } - ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); - cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3)); cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur); @@ -1385,23 +1380,24 @@ bool sam_encode_image( cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps); - // TODO: avoid copy cur = ggml_cpy(ctx0, cur, state.embd_img); - ggml_set_name(cur, "check"); - - ggml_set_scratch(ctx0, { 0, 0, nullptr, }); - - // run the computation - ggml_build_forward_expand(&gf, cur); - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); + ggml_build_forward_expand(gf, cur); + ggml_disconnect_node_from_graph(state.embd_img); //ggml_graph_print(&gf); ggml_free(ctx0); - return true; + + return gf; } + +struct prompt_encoder_result { + struct ggml_tensor * embd_prompt_sparse = {}; + struct ggml_tensor * embd_prompt_dense = {}; +}; + // encode a prompt // // - points @@ -1410,29 +1406,18 @@ bool sam_encode_image( // // TODO: currently just encode a single point for simplicity // -bool sam_encode_prompt( +prompt_encoder_result sam_encode_prompt( const sam_model & model, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, sam_state & state, int nx, int ny, - sam_point point, - int n_threads) { + sam_point point) { const auto & hparams = model.hparams; const auto & enc = model.enc_prompt; - static size_t buf_size = 256u*1024*1024; - static void * buf = malloc(buf_size); - - struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf, - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; - // transform points // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py#L276 { @@ -1447,12 +1432,11 @@ bool sam_encode_prompt( point.y = point.y*(float(ny_new)/ny) + 0.5f; } - printf("point: %f %f\n", point.x, point.y); - struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, 2); - // set the input by converting the [0, 1] coordinates to [-1, 1] - { + ggml_allocr_alloc(state.allocr, inp); + if (!ggml_allocr_is_measure(state.allocr)) { + // set the input by converting the [0, 1] coordinates to [-1, 1] float * data = (float *) inp->data; data[0] = 2.0f*(point.x / hparams.n_img_size()) - 1.0f; @@ -1468,7 +1452,6 @@ bool sam_encode_prompt( cur = ggml_scale(ctx0, cur, ggml_new_f32(ctx0, 2.0f*M_PI)); - // concat // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192 { @@ -1477,44 +1460,38 @@ bool sam_encode_prompt( cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0))); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1]))); // overwrite 
label == -1 with not_a_point_embed.weight // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L86 // TODO: extend for multiple points - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], cur->nb[1]))); - - // add point_embeddings[1] to label == 1 - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90 - ggml_build_forward_expand(&gf, ggml_add_inplace(ctx0, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0), enc.pt_embd[1])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], cur->nb[1]))); } - // TODO: avoid copy - cur = ggml_cpy(ctx0, cur, state.embd_prompt_sparse); + // add point_embeddings[1] to label == 1 + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90 + struct ggml_tensor * v = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_add_inplace(ctx0, v, enc.pt_embd[1]), v)); - ggml_build_forward_expand(&gf, cur); + struct ggml_tensor * embd_prompt_sparse = cur; + ggml_build_forward_expand(gf, embd_prompt_sparse); - cur = ggml_repeat(ctx0, + struct ggml_tensor * embd_prompt_dense = ggml_repeat(ctx0, ggml_cont(ctx0, ggml_view_3d(ctx0, enc.no_mask_embd_w, 1, 1, enc.no_mask_embd_w->ne[0], enc.no_mask_embd_w->nb[0], enc.no_mask_embd_w->nb[0], 0)), - state.embd_prompt_dense); - - // TODO: avoid copy - cur = ggml_cpy(ctx0, cur, state.embd_prompt_dense); - - ggml_build_forward_expand(&gf, cur); + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hparams.n_img_embd(), hparams.n_img_embd(), hparams.n_enc_out_chans)); - ggml_set_name(cur, "check"); - - // run the computation - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); + ggml_build_forward_expand(gf, embd_prompt_dense); //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); - ggml_free(ctx0); - return true; + prompt_encoder_result res; + res.embd_prompt_sparse = embd_prompt_sparse; + res.embd_prompt_dense = embd_prompt_dense; + + return res; } struct ggml_tensor* sam_decode_mask_transformer_attn( @@ -1546,9 +1523,6 @@ struct ggml_tensor* sam_decode_mask_transformer_attn( ggml_repeat(ctx0, attn.v_b, Vcur), Vcur); - // TODO: use stratch memory - // ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); - struct ggml_tensor * Q = {}; struct ggml_tensor * K = {}; struct ggml_tensor * V = {}; @@ -1617,46 +1591,33 @@ struct ggml_tensor * sam_decode_mask_mlp_relu_3( } bool sam_decode_mask( - const sam_model & model, - sam_state & state, - int n_threads) { + const sam_model & model, + const prompt_encoder_result & prompt, + struct ggml_tensor * pe_img, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + sam_state & state) { const auto & hparams = model.hparams; const auto & dec = model.dec; const int n_img_embd = hparams.n_img_embd(); - static size_t buf_size = 384u*1024*1024; - static void * buf = malloc(buf_size); - - struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf, - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; - - // print_t_f32("embd_img", state.embd_img); - // print_t_f32("embd_prompt_dense", state.embd_prompt_dense); - // print_t_f32("embd_prompt_sparse", state.embd_prompt_sparse); - // print_t_f32("pe_img_dense", state.pe_img_dense); - struct ggml_tensor * tokens = 
{}; { // Concatenate output tokens // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L120 - const auto& sparse = state.embd_prompt_sparse; + const auto& sparse = prompt.embd_prompt_sparse; tokens = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, dec.iou_token_w->ne[0], dec.iou_token_w->ne[1] + dec.mask_tokens_w->ne[1] + sparse->ne[1], sparse->ne[2]); const size_t offsets[3] = { 0, dec.iou_token_w->ne[1]*tokens->nb[1], dec.iou_token_w->ne[1]*tokens->nb[1] + dec.mask_tokens_w->ne[1]*tokens->nb[1] }; - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, dec.iou_token_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.iou_token_w->ne[1], tokens->nb[1], offsets[0]))); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1]))); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, sparse, ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1], tokens->nb[1], offsets[2]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.iou_token_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.iou_token_w->ne[1], tokens->nb[1], offsets[0]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, sparse, ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1], tokens->nb[1], offsets[2]))); // TODO: Sparse prompt embeddings can have more than one point } + struct ggml_tensor * src = {}; struct ggml_tensor * pos_src = {}; int srcNE[4] = { 0, 0, 0, 0 }; @@ -1664,11 +1625,12 @@ bool sam_decode_mask( // Expand per-image data in the batch direction to be per-mask // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L125 src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.embd_img->ne[0], state.embd_img->ne[1], state.embd_img->ne[2], tokens->ne[2]); + src = ggml_add(ctx0, ggml_repeat(ctx0, state.embd_img, src), - state.embd_prompt_dense); + prompt.embd_prompt_dense); srcNE[0] = src->ne[0]; srcNE[1] = src->ne[1]; @@ -1687,11 +1649,10 @@ bool sam_decode_mask( src->nb[3], 0), 1, 0, 2, 3)); - ggml_build_forward_expand(&gf, src); - pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.pe_img_dense->ne[0], state.pe_img_dense->ne[1], state.pe_img_dense->ne[2], tokens->ne[2]); + pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, pe_img->ne[0], pe_img->ne[1], pe_img->ne[2], tokens->ne[2]); pos_src = ggml_repeat(ctx0, - state.pe_img_dense, + pe_img, pos_src); // flatten & permute @@ -1706,7 +1667,6 @@ bool sam_decode_mask( pos_src->nb[3], 0), 1, 0, 2, 3)); - ggml_build_forward_expand(&gf, pos_src); } struct ggml_tensor * queries = tokens; @@ -1811,6 +1771,7 @@ bool sam_decode_mask( ggml_repeat(ctx0, dec.transformer_norm_final_b, queries)); } + struct ggml_tensor * iou_pred = ggml_view_2d(ctx0, queries, queries->ne[0], queries->ne[2], queries->nb[2], 0); const int num_mask_tokens = 4; // num_multimask_outputs + 1 struct ggml_tensor * mask_tokens_out = ggml_view_3d(ctx0, queries, queries->ne[0], num_mask_tokens, queries->ne[2], queries->nb[1], num_mask_tokens*queries->nb[1], queries->nb[1]); @@ -1819,11 +1780,11 @@ bool sam_decode_mask( // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L136 keys = 
ggml_cont(ctx0, ggml_transpose(ctx0, keys)); keys = ggml_view_4d(ctx0, keys, srcNE[0], srcNE[1], srcNE[2], srcNE[3], srcNE[0]*keys->nb[0], keys->nb[1], keys->nb[2], 0); - struct ggml_tensor * upscaled_embedding = {}; { // ConvTranspose2d keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2); + ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed keys = ggml_add(ctx0, keys, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]), keys)); @@ -1835,6 +1796,7 @@ bool sam_decode_mask( // ConvTranspose2d keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2); + ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed keys = ggml_add(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]), keys), keys); @@ -1845,11 +1807,12 @@ bool sam_decode_mask( } struct ggml_tensor * hyper_in = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_img_embd/2, num_mask_tokens, mask_tokens_out->ne[2]); + for (int i = 0; i < num_mask_tokens; ++i) { const auto& mlp = dec.output_hypernet_mlps[i]; struct ggml_tensor * in = ggml_view_2d(ctx0, mask_tokens_out, mask_tokens_out->ne[0], mask_tokens_out->ne[2], mask_tokens_out->nb[1], i*mask_tokens_out->nb[1]); struct ggml_tensor * out = sam_decode_mask_mlp_relu_3(in, mlp.w_0, mlp.b_0, mlp.w_1, mlp.b_1, mlp.w_2, mlp.b_2, ctx0); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1]))); } struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding); @@ -1858,39 +1821,21 @@ bool sam_decode_mask( // Generate mask quality predictions // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146 - state.iou_predictions = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0); + iou_pred = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0); // Select the correct mask or masks for output // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L101 - state.iou_predictions = ggml_view_1d(ctx0, state.iou_predictions, state.iou_predictions->ne[0] - 1, state.iou_predictions->nb[0]); - - state.low_res_masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3], - masks->nb[0], masks->nb[1], masks->nb[2], - masks->nb[2] /* offset*/); - // ggml_set_name(queries, "queries"); - // ggml_set_name(upscaled_embedding, "upscaled_embedding"); - // ggml_set_name(state.low_res_masks, "low_res_masks"); - // ggml_set_name(state.iou_predictions, "iou_predictions"); - // ggml_set_name(hyper_in, "hyper_in"); - - ggml_build_forward_expand(&gf, state.iou_predictions); - ggml_build_forward_expand(&gf, state.low_res_masks); - - // run the computation - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - - // auto * t = 
ggml_get_tensor(ctx0, "queries"); - // print_t_f32("queries", t); - // t = ggml_get_tensor(ctx0, "upscaled_embedding"); - // print_t_f32("upscaled_embedding", t); - // t = ggml_get_tensor(ctx0, "low_res_masks"); - // print_t_f32("low_res_masks", t); - // t = ggml_get_tensor(ctx0, "iou_predictions"); - // print_t_f32("iou_predictions", t); - // t = ggml_get_tensor(ctx0, "hyper_in"); - // print_t_f32("hyper_in", t); + iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, iou_pred->ne[0] - 1, iou_pred->nb[0]), state.iou_predictions); + masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3], + masks->nb[1], masks->nb[2], masks->nb[3], masks->nb[2] /* offset*/); + masks = ggml_cpy(state.ctx, masks, state.low_res_masks); + + ggml_build_forward_expand(gf, masks); + ggml_build_forward_expand(gf, iou_pred); + + ggml_disconnect_node_from_graph(state.low_res_masks); + ggml_disconnect_node_from_graph(state.iou_predictions); - ggml_free(ctx0); return true; } @@ -1929,6 +1874,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state for (int i = 0; i < ne2; ++i) { if (iou_threshold > 0.f && iou_data[i] < iou_threshold) { + printf("Skipping mask %d with iou %f below threshold %f\n", i, iou_data[i], iou_threshold); continue; // Filtering masks with iou below the threshold } @@ -2033,6 +1979,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state const float stability_score = float(intersections) / float(unions); if (stability_score_threshold > 0.f && stability_score < stability_score_threshold) { + printf("Skipping mask %d with stability score %f below threshold %f\n", i, stability_score, stability_score_threshold); continue; // Filtering masks with stability score below the threshold } @@ -2050,12 +1997,105 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state return true; } +struct ggml_cgraph * sam_build_fast_graph( + const sam_model & model, + sam_state & state, + int nx, + int ny, + sam_point point) { + + struct ggml_init_params ggml_params = { + /*.mem_size =*/ state.buf_compute_fast.size(), + /*.mem_buffer =*/ state.buf_compute_fast.data(), + /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements + }; + + struct ggml_context * ctx0 = ggml_init(ggml_params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point); + if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) { + fprintf(stderr, "%s: failed to encode prompt\n", __func__); + return {}; + } + + struct ggml_tensor * pe_img_dense = sam_fill_dense_pe(model, ctx0, gf, state); + if (!pe_img_dense) { + fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__); + return {}; + } + + if (!sam_decode_mask(model, enc_res, pe_img_dense, ctx0, gf, state)) { + fprintf(stderr, "%s: failed to decode mask\n", __func__); + return {}; + } + + ggml_free(ctx0); + + return gf; +} +struct sam_params { + int32_t seed = -1; // RNG seed + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path + std::string fname_inp = "img.jpg"; + std::string fname_out = "img.out"; +}; + +void sam_print_usage(int argc, char ** argv, const sam_params & params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, 
--help show this help message and exit\n"); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); + fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " -m FNAME, --model FNAME\n"); + fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); + fprintf(stderr, " -i FNAME, --inp FNAME\n"); + fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str()); + fprintf(stderr, " -o FNAME, --out FNAME\n"); + fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str()); + fprintf(stderr, "\n"); +} + +bool sam_params_parse(int argc, char ** argv, sam_params & params) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-s" || arg == "--seed") { + params.seed = std::stoi(argv[++i]); + } else if (arg == "-t" || arg == "--threads") { + params.n_threads = std::stoi(argv[++i]); + } else if (arg == "-m" || arg == "--model") { + params.model = argv[++i]; + } else if (arg == "-i" || arg == "--inp") { + params.fname_inp = argv[++i]; + } else if (arg == "-o" || arg == "--out") { + params.fname_out = argv[++i]; + } else if (arg == "-h" || arg == "--help") { + sam_print_usage(argc, argv, params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + sam_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + int main(int argc, char ** argv) { const int64_t t_main_start_us = ggml_time_us(); sam_params params; params.model = "models/sam-vit-b/ggml-model-f16.bin"; + sam_model model; + sam_state state; + int64_t t_load_us = 0; + if (sam_params_parse(argc, argv, params) == false) { return 1; } @@ -2063,7 +2103,6 @@ int main(int argc, char ** argv) { if (params.seed < 0) { params.seed = time(NULL); } - fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); // load the image @@ -2072,7 +2111,6 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str()); return 1; } - fprintf(stderr, "%s: loaded image '%s' (%d x %d)\n", __func__, params.fname_inp.c_str(), img0.nx, img0.ny); // preprocess to f32 @@ -2081,24 +2119,8 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: failed to preprocess image\n", __func__); return 1; } - fprintf(stderr, "%s: preprocessed image (%d x %d)\n", __func__, img1.nx, img1.ny); -#if 0 - { - const int n = 128; - fprintf(stderr, "%s: first %d diagonal pixels:\n", __func__, n); - for (int i = 0; i < n; i++) { - const int ii = i*img1.nx + i; - fprintf(stderr, "%s: %d: %f %f %f\n", __func__, i, img1.data[3*ii + 0], img1.data[3*ii + 1], img1.data[3*ii + 2]); - } - } -#endif - - int64_t t_load_us = 0; - - sam_model model; - sam_state state; // load the model { @@ -2115,48 +2137,97 @@ int main(int argc, char ** argv) { { static size_t buf_size = 256u*1024*1024; - struct ggml_init_params params = { + struct ggml_init_params ggml_params = { /*.mem_size =*/ buf_size, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false, }; - state.ctx = ggml_init(params); + state.ctx = ggml_init(ggml_params); state.embd_img = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans); - // TODO: should depend on the number of points / boxes / etc - state.embd_prompt_sparse = ggml_new_tensor_2d(state.ctx, GGML_TYPE_F32, model.hparams.n_enc_out_chans, 2); + state.low_res_masks = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, + model.hparams.n_enc_out_chans, 
model.hparams.n_enc_out_chans, 3); - state.embd_prompt_dense = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, - model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans); - - state.pe_img_dense = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, - model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans); + state.iou_predictions = ggml_new_tensor_1d(state.ctx, GGML_TYPE_F32, 3); } - if (!sam_fill_dense_pe(model, state, params.n_threads)) { - fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__); - return 1; - } - if (!sam_encode_image(model, state, img1, params.n_threads)) { - fprintf(stderr, "%s: failed to encode image\n", __func__); - return 1; - } + static const size_t tensor_alignment = 32; + { + state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + state.allocr = ggml_allocr_new_measure(tensor_alignment); + struct ggml_cgraph * gf_measure = sam_encode_image(model, state, img1); + if (!gf_measure) { + fprintf(stderr, "%s: failed to encode image\n", __func__); + return 1; + } - // TODO: user input - const sam_point pt = { 414.375f, 162.796875f, }; + size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment; + ggml_allocr_free(state.allocr); - if (!sam_encode_prompt(model, state, img0.nx, img0.ny, pt, params.n_threads)) { - fprintf(stderr, "%s: failed to encode prompt\n", __func__); - return 1; + // recreate allocator with exact memory requirements + state.buf_alloc_img_enc.resize(alloc_size); + state.allocr = ggml_allocr_new(state.buf_alloc_img_enc.data(), state.buf_alloc_img_enc.size(), tensor_alignment); + + // compute the graph with the measured exact memory requirements from above + ggml_allocr_reset(state.allocr); + + struct ggml_cgraph * gf = sam_encode_image(model, state, img1); + if (!gf) { + fprintf(stderr, "%s: failed to encode image\n", __func__); + return 1; + } + + ggml_allocr_alloc_graph(state.allocr, gf); + + ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads); + + print_t_f32("embd_img", state.embd_img); + + ggml_allocr_free(state.allocr); + state.allocr = NULL; + state.work_buffer.clear(); } + { + state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + state.allocr = ggml_allocr_new_measure(tensor_alignment); + + // TODO: user input + const sam_point pt = { 414.375f, 162.796875f, }; + // measure memory requirements for the graph + struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt); + if (!gf_measure) { + fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__); + return 1; + } - if (!sam_decode_mask(model, state, params.n_threads)) { - fprintf(stderr, "%s: failed to decode mask\n", __func__); - return 1; + size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment; + ggml_allocr_free(state.allocr); + + // recreate allocator with exact memory requirements + state.buf_alloc_fast.resize(alloc_size); + state.allocr = ggml_allocr_new(state.buf_alloc_fast.data(), state.buf_alloc_fast.size(), tensor_alignment); + + // compute the graph with the measured exact memory requirements from above + ggml_allocr_reset(state.allocr); + + struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt); + if (!gf) { + fprintf(stderr, "%s: failed to build fast graph\n", __func__); + return 1; + } + + ggml_allocr_alloc_graph(state.allocr, gf); + + 
ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads); + + //print_t_f32("iou_predictions", state.iou_predictions); + //print_t_f32("low_res_masks", state.low_res_masks); + ggml_allocr_free(state.allocr); + state.allocr = NULL; } if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) { diff --git a/src/ggml-alloc.c b/src/ggml-alloc.c index 140e9a2a..856a4cdb 100644 --- a/src/ggml-alloc.c +++ b/src/ggml-alloc.c @@ -378,6 +378,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_SET: case GGML_OP_SOFT_MAX: case GGML_OP_CONT: + case GGML_OP_ADD_REL_POS: return true; default: diff --git a/src/ggml.c b/src/ggml.c index dadb3075..b199bb1a 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -7464,8 +7464,6 @@ static struct ggml_tensor * ggml_add_rel_pos_impl( } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); - result->op = GGML_OP_ADD_REL_POS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; @@ -14978,8 +14976,11 @@ static void ggml_compute_forward_add_rel_pos_f32( const struct ggml_tensor * src1, const struct ggml_tensor * src2, struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == dst->nb[0] && src0->nb[1] == dst->nb[1] + && src0->nb[2] == dst->nb[2] && src0->nb[3] == dst->nb[3]); - const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; + const bool inplace = dst->data == src0->data; if (!inplace && params->type == GGML_TASK_INIT) { memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); return;
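
The main.cpp changes above replace the old per-stage contexts (a static 256 MB malloc in each function, plus the 2 GB / 512 MB scratch buffers in sam_encode_image) with ggml-alloc's two-pass, measure-then-allocate scheme. Below is a minimal sketch of that scheme, using only the ggml-alloc calls that appear in this diff; build_graph is a hypothetical stand-in for sam_encode_image / sam_build_fast_graph, which construct the graph in a no_alloc context sized as ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead().

    #include "ggml.h"
    #include "ggml-alloc.h"

    #include <cstdint>
    #include <vector>

    // hypothetical stand-in for sam_encode_image / sam_build_fast_graph:
    // builds the graph in a no_alloc context and, for each input tensor,
    // calls ggml_allocr_alloc() and fills the data only when not measuring
    struct ggml_cgraph * build_graph(struct ggml_allocr * allocr);

    void run_with_exact_memory(int n_threads) {
        static const size_t tensor_alignment = 32;

        // pass 1 - measure: the measure allocator hands out fake addresses
        // while recording the worst-case memory the graph needs
        struct ggml_allocr * allocr = ggml_allocr_new_measure(tensor_alignment);
        const size_t alloc_size = ggml_allocr_alloc_graph(allocr, build_graph(allocr)) + tensor_alignment;
        ggml_allocr_free(allocr);

        // pass 2 - recreate the allocator on top of a buffer of exactly that size
        std::vector<uint8_t> buf_alloc(alloc_size);
        allocr = ggml_allocr_new(buf_alloc.data(), buf_alloc.size(), tensor_alignment);

        // pass 3 - rebuild the same graph, allocate it for real and compute it
        ggml_allocr_reset(allocr);
        struct ggml_cgraph * gf = build_graph(allocr);
        ggml_allocr_alloc_graph(allocr, gf);

        std::vector<uint8_t> work_buffer;
        struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
        if (plan.work_size > 0) {
            work_buffer.resize(plan.work_size);
            plan.work_data = work_buffer.data();
        }
        ggml_graph_compute(gf, &plan);

        ggml_allocr_free(allocr);
    }

Both passes must produce a graph with identical shapes, which is why main() now builds each graph twice and why the graph-building functions receive the allocator through sam_state instead of allocating eagerly.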
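
A consequence of building each graph twice is that input tensors can no longer be filled at creation time: during the measure pass they have no real backing memory. This is the recurring guard in sam_fill_dense_pe, sam_encode_image and sam_encode_prompt above; schematically (a sketch, not a verbatim excerpt):

    struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, 2);
    ggml_allocr_alloc(state.allocr, inp);        // measure pass: only the size is recorded
    if (!ggml_allocr_is_measure(state.allocr)) { // real pass: inp->data now points into the alloc buffer
        float * data = (float *) ggml_get_data(inp);
        data[0] = 0.0f;                          // fill the actual input values here
    }

ggml_disconnect_node_from_graph() plays the complementary role on the output side: state.embd_img, state.low_res_masks and state.iou_predictions live in the long-lived state.ctx, and clearing their op and src pointers after the copy lets the next graph consume them as plain leafs rather than re-expanding producer nodes whose context has already been freed.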
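
The trailing ggml-alloc.c and ggml.c hunks belong to the same change: once GGML_OP_ADD_REL_POS is whitelisted in ggml_op_can_inplace(), the allocator is free to place dst on top of src0 even for nodes that were not created with the inplace variant of ggml_add_rel_pos, so an inplace flag frozen into op_params at graph-build time can become stale. The forward pass therefore stops trusting the flag and detects aliasing at compute time via dst->data == src0->data, copying src0 into dst during GGML_TASK_INIT only when the two do not alias.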