From: Taylor Date: Thu, 8 May 2025 15:56:59 +0000 (-0500) Subject: sam : support box prompt (#1206) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=153e8bcb62811d608f5cddc13a8bd1c5de7eab25;p=pkg%2Fggml%2Fsources%2Fggml sam : support box prompt (#1206) * SAM: parse box prompt * support box prompt * sam: rename main.cpp to sam.cpp * pass sam_prompt as const & * support single mask output (default multimask) * Update examples/sam/sam.cpp Co-authored-by: Georgi Gerganov * Update examples/sam/sam.cpp Co-authored-by: Georgi Gerganov * Update examples/sam/sam.cpp Co-authored-by: Georgi Gerganov * Update examples/sam/sam.cpp Co-authored-by: Georgi Gerganov * Update examples/sam/sam.cpp Co-authored-by: Georgi Gerganov * removed debug print * test sam example: box prompt and single mask output --------- Co-authored-by: Georgi Gerganov --- diff --git a/ci/run.sh b/ci/run.sh index 1738aac2..66f15cd7 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -254,10 +254,16 @@ function gg_run_sam { python3 ../examples/sam/convert-pth-to-ggml.py ${path_models}/sam_vit_b_01ec64.pth ${path_models}/ 1 + # Test default parameters (time ./bin/sam -m ${model_f16} -i ${img_0} ) 2>&1 | tee -a $OUT/${ci}-main.log - + grep -q "point prompt" $OUT/${ci}-main.log grep -q "bbox (371, 436), (144, 168)" $OUT/${ci}-main.log + # Test box prompt and single mask output + (time ./bin/sam -m ${model_f16} -i ${img_0} -b 368,144,441,173 -sm) 2>&1 | tee -a $OUT/${ci}-main.log + grep -q "box prompt" $OUT/${ci}-main.log + grep -q "bbox (370, 439), (144, 169)" $OUT/${ci}-main.log + set +e } diff --git a/examples/sam/CMakeLists.txt b/examples/sam/CMakeLists.txt index fa302bb4..5033618a 100644 --- a/examples/sam/CMakeLists.txt +++ b/examples/sam/CMakeLists.txt @@ -2,7 +2,7 @@ # sam set(TEST_TARGET sam) -add_executable(${TEST_TARGET} main.cpp) +add_executable(${TEST_TARGET} sam.cpp) target_link_libraries(${TEST_TARGET} PRIVATE ggml common) # diff --git a/examples/sam/main.cpp b/examples/sam/main.cpp deleted file mode 100644 index eeb7e5cd..00000000 --- a/examples/sam/main.cpp +++ /dev/null @@ -1,2233 +0,0 @@ -#define _USE_MATH_DEFINES // for M_PI -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows - -#include "ggml.h" -#include "ggml-cpu.h" -#include "ggml-alloc.h" -#include "ggml-backend.h" -#define STB_IMAGE_IMPLEMENTATION -#include "stb_image.h" -#define STB_IMAGE_WRITE_IMPLEMENTATION -#include "stb_image_write.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -// default hparams (ViT-B SAM) -struct sam_hparams { - int32_t n_enc_state = 768; - int32_t n_enc_layer = 12; - int32_t n_enc_head = 12; - int32_t n_enc_out_chans = 256; - int32_t n_pt_embd = 4; - int32_t n_dec_heads = 8; - int32_t ftype = 1; - float mask_threshold = 0.f; - float iou_threshold = 0.88f; - float stability_score_threshold = 0.95f; - float stability_score_offset = 1.0f; - float eps = 1e-6f; - float eps_decoder_transformer = 1e-5f; - - int32_t n_enc_head_dim() const { return n_enc_state / n_enc_head; } - int32_t n_img_size() const { return 1024; } - int32_t n_window_size() const { return 14; } - int32_t n_patch_size() const { return 16; } - int32_t n_img_embd() const { return n_img_size() / n_patch_size(); } - - std::vector global_attn_indices() const { - switch (n_enc_state) { - case 768: return { 2, 5, 8, 11 }; - case 1024: return { 5, 11, 17, 23 }; - case 1280: return { 7, 15, 23, 31 }; - default: - { - fprintf(stderr, "%s: unsupported n_enc_state = %d\n", __func__, n_enc_state); - } break; - }; - - return {}; - } - - bool is_global_attn(int32_t layer) const { - const auto indices = global_attn_indices(); - - for (const auto & idx : indices) { - if (layer == idx) { - return true; - } - } - - return false; - } -}; - -struct sam_layer_enc { - struct ggml_tensor * norm1_w; - struct ggml_tensor * norm1_b; - - struct ggml_tensor * rel_pos_w; - struct ggml_tensor * rel_pos_h; - - struct ggml_tensor * qkv_w; - struct ggml_tensor * qkv_b; - - struct ggml_tensor * proj_w; - struct ggml_tensor * proj_b; - - struct ggml_tensor * norm2_w; - struct ggml_tensor * norm2_b; - - struct ggml_tensor * mlp_lin1_w; - struct ggml_tensor * mlp_lin1_b; - - struct ggml_tensor * mlp_lin2_w; - struct ggml_tensor * mlp_lin2_b; -}; - -struct sam_encoder_image { - struct ggml_tensor * pe; - - struct ggml_tensor * proj_w; - struct ggml_tensor * proj_b; - - struct ggml_tensor * neck_conv_0; - struct ggml_tensor * neck_norm_0_w; - struct ggml_tensor * neck_norm_0_b; - struct ggml_tensor * neck_conv_1; - struct ggml_tensor * neck_norm_1_w; - struct ggml_tensor * neck_norm_1_b; - - std::vector layers; -}; - -struct sam_encoder_prompt { - struct ggml_tensor * pe; - - struct ggml_tensor * not_a_pt_embd_w; - std::vector pt_embd; - - struct ggml_tensor * no_mask_embd_w; - //std::vector mask_down_w; - //std::vector mask_down_b; -}; - -struct sam_layer_dec_transformer_attn { - // q_proj - struct ggml_tensor * q_w; - struct ggml_tensor * q_b; - - // k_proj - struct ggml_tensor * k_w; - struct ggml_tensor * k_b; - - // v_proj - struct ggml_tensor * v_w; - struct ggml_tensor * v_b; - - // out_proj - struct ggml_tensor * out_w; - struct ggml_tensor * out_b; -}; - -struct sam_layer_dec_transformer { - sam_layer_dec_transformer_attn self_attn; - - // norm1 - struct ggml_tensor * norm1_w; - struct ggml_tensor * norm1_b; - - sam_layer_dec_transformer_attn cross_attn_token_to_img; - - // norm2 - struct ggml_tensor * norm2_w; - struct ggml_tensor * norm2_b; - - // mlp.lin1 - struct ggml_tensor * mlp_lin1_w; - struct ggml_tensor * mlp_lin1_b; - - // mlp.lin2 - struct ggml_tensor * mlp_lin2_w; - struct ggml_tensor * mlp_lin2_b; - - // norm3 - struct ggml_tensor * norm3_w; - struct ggml_tensor * norm3_b; - - // norm4 - struct ggml_tensor * norm4_w; - struct ggml_tensor * norm4_b; - - sam_layer_dec_transformer_attn cross_attn_img_to_token; -}; - -struct sam_layer_dec_output_hypernet_mlps { - // mlps_*.layers.0 - struct ggml_tensor * w_0; - struct ggml_tensor * b_0; - - // mlps_*.layers.1 - struct ggml_tensor * w_1; - struct ggml_tensor * b_1; - - // mlps_*.layers.2 - struct ggml_tensor * w_2; - struct ggml_tensor * b_2; -}; - -struct sam_decoder_mask { - std::vector transformer_layers; - - // trasnformer.final_attn_token_to_image - sam_layer_dec_transformer_attn transformer_final_attn_token_to_img; - - // transformer.norm_final - struct ggml_tensor * transformer_norm_final_w; - struct ggml_tensor * transformer_norm_final_b; - - // output_upscaling.0 - struct ggml_tensor * output_upscaling_0_w; - struct ggml_tensor * output_upscaling_0_b; - - // output_upscaling.1 - struct ggml_tensor * output_upscaling_1_w; - struct ggml_tensor * output_upscaling_1_b; - - // output_upscaling.3 - struct ggml_tensor * output_upscaling_3_w; - struct ggml_tensor * output_upscaling_3_b; - - // output_hypernetworks_mlps - std::vector output_hypernet_mlps; - - // iou_prediction_head.0 - struct ggml_tensor * iou_prediction_head_0_w; - struct ggml_tensor * iou_prediction_head_0_b; - - // iou_prediction_head.1 - struct ggml_tensor * iou_prediction_head_1_w; - struct ggml_tensor * iou_prediction_head_1_b; - - // iou_prediction_head.2 - struct ggml_tensor * iou_prediction_head_2_w; - struct ggml_tensor * iou_prediction_head_2_b; - - // iou_token.weight - struct ggml_tensor * iou_token_w; - - // mask_tokens.weight - struct ggml_tensor * mask_tokens_w; -}; - - -struct sam_state { - struct ggml_tensor * embd_img; - - struct ggml_tensor * low_res_masks; - struct ggml_tensor * iou_predictions; - - //struct ggml_tensor * tmp_save = {}; - - struct ggml_context * ctx; - - // buffer for `ggml_graph_plan.work_data` - std::vector work_buffer; - // buffers to evaluate the model - std::vector buf_compute_img_enc; - - std::vector buf_compute_fast; - - ggml_gallocr_t allocr = {}; -}; - -// void save_tensor(sam_state& state, struct ggml_tensor * t, struct ggml_cgraph * gf) { -// if (!state.tmp_save) { -// state.tmp_save = ggml_new_tensor(state.ctx, t->type, t->n_dims, t->ne); -// } -// struct ggml_tensor * tmp0 = ggml_cpy(state.ctx, t, state.tmp_save); -// ggml_build_forward_expand(gf, tmp0); -// } - -struct sam_model { - sam_hparams hparams; - - sam_encoder_image enc_img; - sam_encoder_prompt enc_prompt; - sam_decoder_mask dec; - - // - struct ggml_context * ctx; - std::map tensors; -}; - -struct sam_point { - float x; - float y; -}; - -// RGB uint8 image -struct sam_image_u8 { - int nx; - int ny; - - std::vector data; -}; - -// RGB float32 image -// Memory layout: RGBRGBRGB... -struct sam_image_f32 { - int nx; - int ny; - - std::vector data; -}; - -struct sam_params { - int32_t seed = -1; // RNG seed - int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - - std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path - std::string fname_inp = "img.jpg"; - std::string fname_out = "img.out"; - float mask_threshold = 0.f; - float iou_threshold = 0.88f; - float stability_score_threshold = 0.95f; - float stability_score_offset = 1.0f; - float eps = 1e-6f; - float eps_decoder_transformer = 1e-5f; - sam_point pt = { 414.375f, 162.796875f, }; -}; - -void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) { - printf("%s\n", title); - float * data = (float *)t->data; - printf("dims: % " PRId64 " % " PRId64 " % " PRId64 " % " PRId64 " f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]); - printf("First & Last %d elements:\n", n); - for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) { - printf("%.5f ", data[i]); - if (i != 0 && i % t->ne[0] == 0) { - printf("\n"); - } - } - printf("\n"); - for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) { - printf("%.5f ", data[ggml_nelements(t) - n + i]); - if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) { - printf("\n"); - } - } - printf("\n"); - double sum = 0.0; - for (int i = 0; i < ggml_nelements(t); i++) { - sum += data[i]; - } - printf("sum: %f\n\n", sum); -} - -static void ggml_disconnect_node_from_graph(ggml_tensor * t) { - t->op = GGML_OP_NONE; - for (int i = 0; i < GGML_MAX_SRC; i++) { - t->src[i] = NULL; - } -} - -static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr); - - if (plan.work_size > 0) { - buf.resize(plan.work_size); - plan.work_data = buf.data(); - } - - ggml_graph_compute(graph, &plan); -} - -static void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) { - GGML_ASSERT(userdata == NULL); - GGML_ASSERT(ggml_are_same_shape(dst, src)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src)); - - const float * src_data = ggml_get_data_f32(src); - float * dst_data = ggml_get_data_f32(dst); - - const int ne = (int)ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = std::min(ie0 + dr, ne); - - for (int i = ie0; i < ie1; ++i) { - dst_data[i] = sinf(src_data[i]); - } -} - -static void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) { - GGML_ASSERT(userdata == NULL); - GGML_ASSERT(ggml_are_same_shape(dst, src)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src)); - - const float * src_data = ggml_get_data_f32(src); - float * dst_data = ggml_get_data_f32(dst); - - const int ne = (int)ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = std::min(ie0 + dr, ne); - - for (int i = ie0; i < ie1; ++i) { - dst_data[i] = cosf(src_data[i]); - } -} - -bool sam_image_load_from_file(const std::string & fname, sam_image_u8 & img) { - int nx, ny, nc; - auto data = stbi_load(fname.c_str(), &nx, &ny, &nc, 3); - if (!data) { - fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str()); - return false; - } - - img.nx = nx; - img.ny = ny; - img.data.resize(nx * ny * 3); - memcpy(img.data.data(), data, nx * ny * 3); - - stbi_image_free(data); - - return true; -} - -// ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L164 -// resize largest dimension to 1024 -// normalize: x = (x - mean) / std -// mean = [123.675, 116.28, 103.53] -// std = [58.395, 57.12, 57.375] -// TODO: why are these hardcoded !? -// pad to 1024x1024 -// TODO: for some reason, this is not numerically identical to pytorch's interpolation -bool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) { - const int nx = img.nx; - const int ny = img.ny; - - const int nx2 = 1024; - const int ny2 = 1024; - - res.nx = nx2; - res.ny = ny2; - res.data.resize(3*nx2*ny2); - - const float scale = std::max(nx, ny) / 1024.0f; - - fprintf(stderr, "%s: scale = %f\n", __func__, scale); - - const int nx3 = int(nx/scale + 0.5f); - const int ny3 = int(ny/scale + 0.5f); - - const float m3[3] = { 123.675f, 116.280f, 103.530f }; - const float s3[3] = { 58.395f, 57.120f, 57.375f }; - - for (int y = 0; y < ny3; y++) { - for (int x = 0; x < nx3; x++) { - for (int c = 0; c < 3; c++) { - // linear interpolation - const float sx = (x + 0.5f)*scale - 0.5f; - const float sy = (y + 0.5f)*scale - 0.5f; - - const int x0 = std::max(0, (int) std::floor(sx)); - const int y0 = std::max(0, (int) std::floor(sy)); - - const int x1 = std::min(x0 + 1, nx - 1); - const int y1 = std::min(y0 + 1, ny - 1); - - const float dx = sx - x0; - const float dy = sy - y0; - - const int j00 = 3*(y0*nx + x0) + c; - const int j01 = 3*(y0*nx + x1) + c; - const int j10 = 3*(y1*nx + x0) + c; - const int j11 = 3*(y1*nx + x1) + c; - - const float v00 = img.data[j00]; - const float v01 = img.data[j01]; - const float v10 = img.data[j10]; - const float v11 = img.data[j11]; - - const float v0 = v00*(1.0f - dx) + v01*dx; - const float v1 = v10*(1.0f - dx) + v11*dx; - - const float v = v0*(1.0f - dy) + v1*dy; - - const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); - - const int i = 3*(y*nx3 + x) + c; - - res.data[i] = (float(v2) - m3[c]) / s3[c]; - } - } - } - - return true; -} - -// load the model's weights from a file -bool sam_model_load(const sam_params & params, sam_model & model) { - fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, params.model.c_str()); - - auto fin = std::ifstream(params.model, std::ios::binary); - if (!fin) { - fprintf(stderr, "%s: failed to open '%s'\n", __func__, params.model.c_str()); - return false; - } - - // verify magic - { - uint32_t magic; - fin.read((char *) &magic, sizeof(magic)); - if (magic != 0x67676d6c) { - fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, params.model.c_str()); - return false; - } - } - - // load hparams - { - // Override defaults with user choices - model.hparams.mask_threshold = params.mask_threshold; - model.hparams.iou_threshold = params.iou_threshold; - model.hparams.stability_score_threshold = params.stability_score_threshold; - model.hparams.stability_score_offset = params.stability_score_offset; - model.hparams.eps = params.eps; - model.hparams.eps_decoder_transformer = params.eps_decoder_transformer; - - auto & hparams = model.hparams; - - fin.read((char *) &hparams.n_enc_state, sizeof(hparams.n_enc_state)); - fin.read((char *) &hparams.n_enc_layer, sizeof(hparams.n_enc_layer)); - fin.read((char *) &hparams.n_enc_head, sizeof(hparams.n_enc_head)); - fin.read((char *) &hparams.n_enc_out_chans, sizeof(hparams.n_enc_out_chans)); - fin.read((char *) &hparams.n_pt_embd, sizeof(hparams.n_pt_embd)); - fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); - - const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; - - printf("%s: n_enc_state = %d\n", __func__, hparams.n_enc_state); - printf("%s: n_enc_layer = %d\n", __func__, hparams.n_enc_layer); - printf("%s: n_enc_head = %d\n", __func__, hparams.n_enc_head); - printf("%s: n_enc_out_chans = %d\n", __func__, hparams.n_enc_out_chans); - printf("%s: n_pt_embd = %d\n", __func__, hparams.n_pt_embd); - printf("%s: ftype = %d\n", __func__, hparams.ftype); - printf("%s: qntvr = %d\n", __func__, qntvr); - - hparams.ftype %= GGML_QNT_VERSION_FACTOR; - - } - - // for the big tensors, we have the option to store the data in 16-bit floats or quantized - // in order to save memory and also to speed up the computation - ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); - if (wtype == GGML_TYPE_COUNT) { - fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", - __func__, params.model.c_str(), model.hparams.ftype); - return false; - } - - auto & ctx = model.ctx; - - const size_t ctx_size = [&]() { - size_t ctx_size = 0; - - const auto & hparams = model.hparams; - - const int32_t n_enc_state = hparams.n_enc_state; - const int32_t n_enc_layer = hparams.n_enc_layer; - const int32_t n_enc_head_dim = hparams.n_enc_head_dim(); - const int32_t n_enc_out_chans = hparams.n_enc_out_chans; - const int32_t n_pt_embd = hparams.n_pt_embd; - - const int32_t n_enc_layer_local = hparams.global_attn_indices().size(); - const int32_t n_enc_layer_global = n_enc_layer - n_enc_layer_local; - - const int32_t n_img_embd = hparams.n_img_embd(); - const int32_t n_window_size = hparams.n_window_size(); - const int32_t n_patch_size = hparams.n_patch_size(); - - // image encoder - { - ctx_size += n_enc_state*n_img_embd*n_img_embd*ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_state*3*n_patch_size*n_patch_size*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_state*ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_state*n_enc_out_chans*1*1*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_out_chans*n_enc_out_chans*3*3*ggml_type_size(GGML_TYPE_F16); - - ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - } - - // image encoder layers - { - ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); - ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_size(GGML_TYPE_F16); - - ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_size(GGML_TYPE_F16); - - ctx_size += n_enc_layer*3*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_layer*3*n_enc_state* ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_layer*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_layer*n_enc_state* ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); - ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_layer*4*n_enc_state* ggml_type_size(GGML_TYPE_F32); - - ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_enc_layer*4*n_enc_state* ggml_type_size(GGML_TYPE_F32); - } - - ctx_size += (8 + 14*n_enc_layer)*ggml_tensor_overhead(); - - // prompt encoder - { - ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); // 2*(n_enc_out_chans/2) - - ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - } - - ctx_size += (2 + n_pt_embd)*ggml_tensor_overhead(); - - // mask decoder - { - //transformer - { - const int tfm_layers_count = 2; - const int qkv_count = 3; - const int norm_count = 4; - const int n_hypernet_mpls_count = 4; - - // self_attn - ctx_size += tfm_layers_count*qkv_count*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); - ctx_size += tfm_layers_count*qkv_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); - ctx_size += tfm_layers_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); - - // all norms - ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); - ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); - - // cross_attn_token_to_img - ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16); - ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)* ggml_type_size(GGML_TYPE_F32); - ctx_size += tfm_layers_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); - - // mlp - ctx_size += tfm_layers_count*8*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); - ctx_size += tfm_layers_count*8*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); - ctx_size += tfm_layers_count*n_enc_out_chans*8*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); - ctx_size += tfm_layers_count*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); - - // cross_attn_img_to_token - ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16); - ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)* ggml_type_size(GGML_TYPE_F32); - ctx_size += tfm_layers_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); - - // transformer_final_attn_token_to_img - ctx_size += qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16); - ctx_size += qkv_count*(n_enc_state/2)* ggml_type_size(GGML_TYPE_F32); - ctx_size += n_enc_state* ggml_type_size(GGML_TYPE_F32); - - // transformer_norm_final - ctx_size += norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); - ctx_size += norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); - - // output_upscaling - ctx_size += n_enc_out_chans*n_img_embd*2*2*ggml_type_size(GGML_TYPE_F16); - ctx_size += 3*n_img_embd* ggml_type_size(GGML_TYPE_F32); - ctx_size += n_enc_out_chans*n_img_embd*(n_img_embd/2)*2*2*ggml_type_size(GGML_TYPE_F16); - ctx_size += (n_img_embd/2)* ggml_type_size(GGML_TYPE_F32); - - // output_hypernetworks_mlps - ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); - ctx_size += n_hypernet_mpls_count*n_enc_out_chans*(n_img_embd/2)*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_hypernet_mpls_count*(n_img_embd/2)* ggml_type_size(GGML_TYPE_F32); - - // iou_prediction_head - ctx_size += 2*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); - ctx_size += 2*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); - ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); - ctx_size += n_pt_embd* ggml_type_size(GGML_TYPE_F32); - - // iou_token_w - ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - - // mask_tokens_w - ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); - } - } - fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); - - return ctx_size; - }(); - - // create the ggml context - { - struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ false, - }; - - ctx = ggml_init(params); - if (!ctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - } - - // prepare memory for the weights - { - const auto & hparams = model.hparams; - - const int32_t n_enc_state = hparams.n_enc_state; - const int32_t n_enc_layer = hparams.n_enc_layer; - const int32_t n_enc_head_dim = hparams.n_enc_head_dim(); - const int32_t n_enc_out_chans = hparams.n_enc_out_chans; - const int32_t n_pt_embd = hparams.n_pt_embd; - - const int32_t n_img_embd = hparams.n_img_embd(); - const int32_t n_window_size = hparams.n_window_size(); - const int32_t n_patch_size = hparams.n_patch_size(); - - model.enc_img.layers.resize(n_enc_layer); - - // image encoder - { - auto & enc = model.enc_img; - - enc.pe = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_enc_state, n_img_embd, n_img_embd, 1); - - enc.proj_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_patch_size, n_patch_size, 3, n_enc_state); - enc.proj_b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, 1, n_enc_state); - - enc.neck_conv_0 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, n_enc_state, n_enc_out_chans); - enc.neck_conv_1 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, n_enc_out_chans, n_enc_out_chans); - - enc.neck_norm_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - enc.neck_norm_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - enc.neck_norm_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - enc.neck_norm_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - model.tensors["image_encoder.pos_embed"] = enc.pe; - - model.tensors["image_encoder.patch_embed.proj.weight"] = enc.proj_w; - model.tensors["image_encoder.patch_embed.proj.bias"] = enc.proj_b; - - model.tensors["image_encoder.neck.0.weight"] = enc.neck_conv_0; - model.tensors["image_encoder.neck.2.weight"] = enc.neck_conv_1; - - model.tensors["image_encoder.neck.1.weight"] = enc.neck_norm_0_w; - model.tensors["image_encoder.neck.1.bias"] = enc.neck_norm_0_b; - - model.tensors["image_encoder.neck.3.weight"] = enc.neck_norm_1_w; - model.tensors["image_encoder.neck.3.bias"] = enc.neck_norm_1_b; - - for (int i = 0; i < n_enc_layer; ++i) { - auto & layer = enc.layers[i]; - - layer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); - layer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); - - if (hparams.is_global_attn(i)) { - layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1); - layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1); - } else { - layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1); - layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1); - } - - layer.qkv_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, 3*n_enc_state); - layer.qkv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_enc_state); - - layer.proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, n_enc_state); - layer.proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); - - layer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); - layer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); - - layer.mlp_lin1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, 4*n_enc_state); - layer.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_enc_state); - - layer.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4*n_enc_state, n_enc_state); - layer.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); - - model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm1.weight"] = layer.norm1_w; - model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm1.bias"] = layer.norm1_b; - - model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.rel_pos_w"] = layer.rel_pos_w; - model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.rel_pos_h"] = layer.rel_pos_h; - - model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.qkv.weight"] = layer.qkv_w; - model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.qkv.bias"] = layer.qkv_b; - - model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.proj.weight"] = layer.proj_w; - model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.proj.bias"] = layer.proj_b; - - model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm2.weight"] = layer.norm2_w; - model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm2.bias"] = layer.norm2_b; - - model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin1.weight"] = layer.mlp_lin1_w; - model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin1.bias"] = layer.mlp_lin1_b; - - model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin2.weight"] = layer.mlp_lin2_w; - model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin2.bias"] = layer.mlp_lin2_b; - } - } - - // prompt encoder - { - auto & enc = model.enc_prompt; - - enc.pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans/2, 2); - - enc.not_a_pt_embd_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - enc.no_mask_embd_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - model.tensors["prompt_encoder.pe_layer.positional_encoding_gaussian_matrix"] = enc.pe; - model.tensors["prompt_encoder.not_a_point_embed.weight"] = enc.not_a_pt_embd_w; - model.tensors["prompt_encoder.no_mask_embed.weight"] = enc.no_mask_embd_w; - - enc.pt_embd.resize(n_pt_embd); - for (int i = 0; i < n_pt_embd; i++) { - enc.pt_embd[i] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - model.tensors["prompt_encoder.point_embeddings." + std::to_string(i) + ".weight"] = enc.pt_embd[i]; - } - } - - // mask decoder - { - auto & dec = model.dec; - auto & tfm_layers = dec.transformer_layers; - - const int tfm_layers_count = 2; - tfm_layers.resize(tfm_layers_count); - for (int i = 0; i < tfm_layers_count; ++i) { - auto& l = tfm_layers[i]; - l.self_attn.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - l.self_attn.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - l.self_attn.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - l.self_attn.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - l.self_attn.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - l.self_attn.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - l.self_attn.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - l.self_attn.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - l.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - l.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - l.cross_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - l.cross_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - l.cross_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - l.cross_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - l.cross_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - l.cross_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - l.cross_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans); - l.cross_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - l.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - l.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - l.mlp_lin1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, 8*n_enc_out_chans); - l.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8*n_enc_out_chans); - l.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8*n_enc_out_chans, n_enc_out_chans); - l.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - l.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - l.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - l.norm4_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - l.norm4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - l.cross_attn_img_to_token.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - l.cross_attn_img_to_token.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - l.cross_attn_img_to_token.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - l.cross_attn_img_to_token.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - l.cross_attn_img_to_token.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - l.cross_attn_img_to_token.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - l.cross_attn_img_to_token.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans); - l.cross_attn_img_to_token.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - const auto prefix = "mask_decoder.transformer.layers." + std::to_string(i) + "."; - model.tensors[prefix + "self_attn.q_proj.weight"] = l.self_attn.q_w; - model.tensors[prefix + "self_attn.q_proj.bias"] = l.self_attn.q_b; - model.tensors[prefix + "self_attn.k_proj.weight"] = l.self_attn.k_w; - model.tensors[prefix + "self_attn.k_proj.bias"] = l.self_attn.k_b; - model.tensors[prefix + "self_attn.v_proj.weight"] = l.self_attn.v_w; - model.tensors[prefix + "self_attn.v_proj.bias"] = l.self_attn.v_b; - model.tensors[prefix + "self_attn.out_proj.weight"] = l.self_attn.out_w; - model.tensors[prefix + "self_attn.out_proj.bias"] = l.self_attn.out_b; - - model.tensors[prefix + "norm1.weight"] = l.norm1_w; - model.tensors[prefix + "norm1.bias"] = l.norm1_b; - - model.tensors[prefix + "cross_attn_token_to_image.q_proj.weight"] = l.cross_attn_token_to_img.q_w; - model.tensors[prefix + "cross_attn_token_to_image.q_proj.bias"] = l.cross_attn_token_to_img.q_b; - model.tensors[prefix + "cross_attn_token_to_image.k_proj.weight"] = l.cross_attn_token_to_img.k_w; - model.tensors[prefix + "cross_attn_token_to_image.k_proj.bias"] = l.cross_attn_token_to_img.k_b; - model.tensors[prefix + "cross_attn_token_to_image.v_proj.weight"] = l.cross_attn_token_to_img.v_w; - model.tensors[prefix + "cross_attn_token_to_image.v_proj.bias"] = l.cross_attn_token_to_img.v_b; - model.tensors[prefix + "cross_attn_token_to_image.out_proj.weight"] = l.cross_attn_token_to_img.out_w; - model.tensors[prefix + "cross_attn_token_to_image.out_proj.bias"] = l.cross_attn_token_to_img.out_b; - - model.tensors[prefix + "norm2.weight"] = l.norm2_w; - model.tensors[prefix + "norm2.bias"] = l.norm2_b; - - model.tensors[prefix + "mlp.lin1.weight"] = l.mlp_lin1_w; - model.tensors[prefix + "mlp.lin1.bias"] = l.mlp_lin1_b; - model.tensors[prefix + "mlp.lin2.weight"] = l.mlp_lin2_w; - model.tensors[prefix + "mlp.lin2.bias"] = l.mlp_lin2_b; - - model.tensors[prefix + "norm3.weight"] = l.norm3_w; - model.tensors[prefix + "norm3.bias"] = l.norm3_b; - model.tensors[prefix + "norm4.weight"] = l.norm4_w; - model.tensors[prefix + "norm4.bias"] = l.norm4_b; - - model.tensors[prefix + "cross_attn_image_to_token.q_proj.weight"] = l.cross_attn_img_to_token.q_w; - model.tensors[prefix + "cross_attn_image_to_token.q_proj.bias"] = l.cross_attn_img_to_token.q_b; - model.tensors[prefix + "cross_attn_image_to_token.k_proj.weight"] = l.cross_attn_img_to_token.k_w; - model.tensors[prefix + "cross_attn_image_to_token.k_proj.bias"] = l.cross_attn_img_to_token.k_b; - model.tensors[prefix + "cross_attn_image_to_token.v_proj.weight"] = l.cross_attn_img_to_token.v_w; - model.tensors[prefix + "cross_attn_image_to_token.v_proj.bias"] = l.cross_attn_img_to_token.v_b; - model.tensors[prefix + "cross_attn_image_to_token.out_proj.weight"] = l.cross_attn_img_to_token.out_w; - model.tensors[prefix + "cross_attn_image_to_token.out_proj.bias"] = l.cross_attn_img_to_token.out_b; - } - - dec.transformer_final_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - dec.transformer_final_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - dec.transformer_final_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - dec.transformer_final_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - dec.transformer_final_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); - dec.transformer_final_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); - dec.transformer_final_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans); - dec.transformer_final_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - model.tensors["mask_decoder.transformer.final_attn_token_to_image.q_proj.weight"] = dec.transformer_final_attn_token_to_img.q_w; - model.tensors["mask_decoder.transformer.final_attn_token_to_image.q_proj.bias"] = dec.transformer_final_attn_token_to_img.q_b; - model.tensors["mask_decoder.transformer.final_attn_token_to_image.k_proj.weight"] = dec.transformer_final_attn_token_to_img.k_w; - model.tensors["mask_decoder.transformer.final_attn_token_to_image.k_proj.bias"] = dec.transformer_final_attn_token_to_img.k_b; - model.tensors["mask_decoder.transformer.final_attn_token_to_image.v_proj.weight"] = dec.transformer_final_attn_token_to_img.v_w; - model.tensors["mask_decoder.transformer.final_attn_token_to_image.v_proj.bias"] = dec.transformer_final_attn_token_to_img.v_b; - model.tensors["mask_decoder.transformer.final_attn_token_to_image.out_proj.weight"] = dec.transformer_final_attn_token_to_img.out_w; - model.tensors["mask_decoder.transformer.final_attn_token_to_image.out_proj.bias"] = dec.transformer_final_attn_token_to_img.out_b; - - dec.transformer_norm_final_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - dec.transformer_norm_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - - model.tensors["mask_decoder.transformer.norm_final_attn.weight"] = dec.transformer_norm_final_w; - model.tensors["mask_decoder.transformer.norm_final_attn.bias"] = dec.transformer_norm_final_b; - - dec.output_upscaling_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, n_img_embd, n_enc_out_chans); - dec.output_upscaling_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd); - dec.output_upscaling_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd); - dec.output_upscaling_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd); - dec.output_upscaling_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, n_img_embd/2, n_img_embd); - dec.output_upscaling_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2); - - model.tensors["mask_decoder.output_upscaling.0.weight"] = dec.output_upscaling_0_w; - model.tensors["mask_decoder.output_upscaling.0.bias"] = dec.output_upscaling_0_b; - model.tensors["mask_decoder.output_upscaling.1.weight"] = dec.output_upscaling_1_w; - model.tensors["mask_decoder.output_upscaling.1.bias"] = dec.output_upscaling_1_b; - model.tensors["mask_decoder.output_upscaling.3.weight"] = dec.output_upscaling_3_w; - model.tensors["mask_decoder.output_upscaling.3.bias"] = dec.output_upscaling_3_b; - - const int n_hypernet_mpls_count = 4; - dec.output_hypernet_mlps.resize(n_hypernet_mpls_count); - for (int i = 0; i < n_hypernet_mpls_count; ++i) { - auto& mlp = dec.output_hypernet_mlps[i]; - - mlp.w_0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - mlp.b_0 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - mlp.w_1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - mlp.b_1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - mlp.w_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_img_embd/2); - mlp.b_2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2); - - const auto prefix = "mask_decoder.output_hypernetworks_mlps." + std::to_string(i) + "."; - model.tensors[prefix + "layers.0.weight"] = mlp.w_0; - model.tensors[prefix + "layers.0.bias"] = mlp.b_0; - model.tensors[prefix + "layers.1.weight"] = mlp.w_1; - model.tensors[prefix + "layers.1.bias"] = mlp.b_1; - model.tensors[prefix + "layers.2.weight"] = mlp.w_2; - model.tensors[prefix + "layers.2.bias"] = mlp.b_2; - } - - dec.iou_prediction_head_0_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - dec.iou_prediction_head_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - dec.iou_prediction_head_1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); - dec.iou_prediction_head_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); - dec.iou_prediction_head_2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_pt_embd); - dec.iou_prediction_head_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pt_embd); - - dec.iou_token_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, 1); - dec.mask_tokens_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, n_pt_embd); - - model.tensors["mask_decoder.iou_prediction_head.layers.0.weight"] = dec.iou_prediction_head_0_w; - model.tensors["mask_decoder.iou_prediction_head.layers.0.bias"] = dec.iou_prediction_head_0_b; - model.tensors["mask_decoder.iou_prediction_head.layers.1.weight"] = dec.iou_prediction_head_1_w; - model.tensors["mask_decoder.iou_prediction_head.layers.1.bias"] = dec.iou_prediction_head_1_b; - model.tensors["mask_decoder.iou_prediction_head.layers.2.weight"] = dec.iou_prediction_head_2_w; - model.tensors["mask_decoder.iou_prediction_head.layers.2.bias"] = dec.iou_prediction_head_2_b; - - model.tensors["mask_decoder.iou_token.weight"] = dec.iou_token_w; - model.tensors["mask_decoder.mask_tokens.weight"] = dec.mask_tokens_w; - } - } - - // load weights - { - int n_tensors = 0; - size_t total_size = 0; - - fprintf(stderr, "%s: ", __func__); - - while (true) { - int32_t n_dims; - int32_t length; - int32_t ftype; - - fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); - fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); - - if (fin.eof()) { - break; - } - - int64_t nelements = 1; - int64_t ne[4] = { 1, 1, 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - int32_t ne_cur; - fin.read(reinterpret_cast(&ne_cur), sizeof(ne_cur)); - ne[i] = ne_cur; - nelements *= ne[i]; - } - - std::string name(length, 0); - fin.read(&name[0], length); - - if (model.tensors.find(name.data()) == model.tensors.end()) { - fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); - return false; - } - - auto tensor = model.tensors[name.data()]; - //printf("ne0 = %jd, ne1 = %jd, ne2 = %jd, ne3 = %jd\n", ne[0], ne[1], ne[2], ne[3]); - - if (ggml_nelements(tensor) != nelements) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %d, expected %d\n", - __func__, name.data(), (int) nelements, (int) ggml_nelements(tensor)); - return false; - } - - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) { - fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n", - __func__, name.data(), - (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3], - (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3]); - return false; - } - - size_t bpe = 0; - - switch (ftype) { - case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; - case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; - case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; - case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; - default: - { - fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); - return false; - } - }; - - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { - fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), (size_t) nelements*bpe); - return false; - } - - fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); - - total_size += ggml_nbytes(tensor); - if (++n_tensors % 8 == 0) { - fprintf(stderr, "."); - fflush(stdout); - } - } - - if (n_tensors != int(model.tensors.size())) { - fprintf(stderr, "%s: model file has %d tensors, but %d tensors were expected\n", __func__, n_tensors, (int) model.tensors.size()); - return false; - } - - fprintf(stderr, " done\n"); - - fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); - } - - fin.close(); - - return true; -} - -struct ggml_tensor * sam_fill_dense_pe( - const sam_model & model, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - sam_state & state) { - const auto & hparams = model.hparams; - const auto & enc = model.enc_prompt; - - - const int32_t n_img_embd = hparams.n_img_embd(); - struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd); - ggml_set_name(xy_embed_stacked, "xy_embed_stacked"); - ggml_set_input(xy_embed_stacked); - - struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), xy_embed_stacked); - - cur = ggml_scale(ctx0, cur, float(2.0*M_PI)); - - // concat - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192 - { - struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL); - struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL); - - cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0))); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1]))); - } - - struct ggml_tensor * pe_img_dense = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); - ggml_build_forward_expand(gf, pe_img_dense); - - return pe_img_dense; -} - -struct ggml_tensor* sam_layer_norm_2d( - struct ggml_context * ctx0, - struct ggml_tensor * layer, - int n_channels, - struct ggml_tensor * w, - struct ggml_tensor * b, - float eps) { - // LayerNorm2d - // normalize along channel dimmension - // TODO: better implementation - layer = ggml_permute(ctx0, - ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), - 2, 0, 1, 3); - - layer = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), - layer), - ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer)); - - return layer; -} - -struct ggml_cgraph * sam_encode_image( - const sam_model & model, - sam_state & state, - const sam_image_f32 & img) { - - const auto & hparams = model.hparams; - const auto & enc = model.enc_img; - - const int32_t n_enc_state = hparams.n_enc_state; - const int32_t n_enc_layer = hparams.n_enc_layer; - const int32_t n_enc_head = hparams.n_enc_head; - const int32_t n_enc_head_dim = hparams.n_enc_head_dim(); - const int32_t n_enc_out_chans = hparams.n_enc_out_chans; - const int32_t n_img_size = hparams.n_img_size(); - const int32_t n_window_size = hparams.n_window_size(); - - struct ggml_init_params ggml_params = { - /*.mem_size =*/ state.buf_compute_img_enc.size(), - /*.mem_buffer =*/ state.buf_compute_img_enc.data(), - /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements - }; - - struct ggml_context * ctx0 = ggml_init(ggml_params); - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3, 1); - ggml_set_name(inp, "inp"); - ggml_set_input(inp); - - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L392 - struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx0, enc.proj_w, inp); - cur = ggml_add_inplace(ctx0, - cur, - ggml_repeat(ctx0, enc.proj_b, cur)); - - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L394 - // keep in F32 - cur = ggml_cont(ctx0, - ggml_permute(ctx0, cur, 1, 2, 0, 3)); - - // convert to F16 - //cur = ggml_cpy(ctx0, - // ggml_permute(ctx0, cur, 1, 2, 0, 3), - // ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_enc_state, n_img_embd, n_img_embd)); - - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L108-L109 - cur = ggml_add_inplace(ctx0, cur, enc.pe); - - struct ggml_tensor * inpL = cur; - - for (int il = 0; il < n_enc_layer; ++il) { - const auto & layer = enc.layers[il]; - - // norm - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168 - { - cur = ggml_norm(ctx0, inpL, hparams.eps); - - // cur = ln_0_w*cur + ln_0_b - cur = ggml_mul(ctx0, cur, layer.norm1_w); - cur = ggml_add_inplace(ctx0, cur, layer.norm1_b); - } - - const int64_t w0 = cur->ne[1]; - const int64_t h0 = cur->ne[2]; - - if (hparams.is_global_attn(il) == false) { - // local attention layer - apply window partition - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172 - cur = ggml_win_part(ctx0, cur, n_window_size); - } - - const int64_t W = cur->ne[1]; - const int64_t H = cur->ne[2]; - - // self-attention - { - cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); - cur = ggml_add_inplace(ctx0, cur, layer.qkv_b); - - // split qkv into separate tensors - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L225-L229 - const int B = cur->ne[3]; - - cur = ggml_reshape_4d(ctx0, cur, n_enc_state, 3, W*H, B); - cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2)); - - struct ggml_tensor * Q; - struct ggml_tensor * K; - struct ggml_tensor * V; - - Q = ggml_view_3d (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 0*cur->nb[3]); - Q = ggml_reshape_4d(ctx0, Q, n_enc_head_dim, n_enc_head, W*H, B); - Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - Q = ggml_reshape_3d(ctx0, Q, n_enc_head_dim, W*H, B*n_enc_head); - - K = ggml_view_3d (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 1*cur->nb[3]); - K = ggml_reshape_4d(ctx0, K, n_enc_head_dim, n_enc_head, W*H, B); - K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - K = ggml_reshape_3d(ctx0, K, n_enc_head_dim, W*H, B*n_enc_head); - - V = ggml_view_3d (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 2*cur->nb[3]); - V = ggml_reshape_4d(ctx0, V, n_enc_head_dim, n_enc_head, W*H, B); - V = ggml_cont (ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // transposed - V = ggml_reshape_3d(ctx0, V, W*H, n_enc_head_dim, B*n_enc_head); - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - struct ggml_tensor * KQ_scaled = - ggml_scale_inplace(ctx0, - KQ, - 1.0f/sqrtf(n_enc_head_dim)); - - struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W); - struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H); - - struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Q, n_enc_head_dim, W, H, B*n_enc_head); - - struct ggml_tensor * rel_w = ggml_cont(ctx0, ggml_permute(ctx0, - ggml_mul_mat(ctx0, - rw, - ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), - 0, 2, 1, 3)); - struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); - - struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - - cur = - ggml_reshape_4d(ctx0, - ggml_cont(ctx0, - ggml_permute(ctx0, - ggml_reshape_4d(ctx0, KQV, n_enc_head_dim, W*H, n_enc_head, B), - 0, 2, 1, 3)), - n_enc_state, W, H, B); - - cur = ggml_mul_mat(ctx0, layer.proj_w, cur); - cur = ggml_add_inplace(ctx0, cur, layer.proj_b); - } - - if (hparams.is_global_attn(il) == false) { - // local attention layer - reverse window partition - cur = ggml_win_unpart(ctx0, cur, w0, h0, n_window_size); - } - - cur = ggml_add_inplace(ctx0, cur, inpL); - - struct ggml_tensor * inpFF = cur; - - // feed-forward network - { - // norm - { - cur = ggml_norm(ctx0, inpFF, hparams.eps); - - // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_mul(ctx0, cur, layer.norm2_w); - cur = ggml_add_inplace(ctx0, cur, layer.norm2_b); - } - - // fully connected - cur = ggml_mul_mat(ctx0, layer.mlp_lin1_w, cur); - cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin1_b); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - - // projection - cur = ggml_mul_mat(ctx0, layer.mlp_lin2_w, cur); - cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin2_b); - } - - inpL = ggml_add(ctx0, cur, inpFF); - } - - cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3)); - - cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur); - - cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b, hparams.eps); - - cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur); - - cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps); - - cur = ggml_cpy(ctx0, cur, state.embd_img); - - ggml_build_forward_expand(gf, cur); - ggml_disconnect_node_from_graph(state.embd_img); - - //ggml_graph_print(&gf); - - ggml_free(ctx0); - - ggml_gallocr_alloc_graph(state.allocr, gf); - - { - struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "inp"); - float * data = (float *) ggml_get_data(inp); - - const int nx = img.nx; - const int ny = img.ny; - const int n = nx*ny; - - GGML_ASSERT(nx == n_img_size && ny == n_img_size); - - for (int k = 0; k < 3; k++) { - for (int y = 0; y < ny; y++) { - for (int x = 0; x < nx; x++) { - data[k*n + y*nx + x] = img.data[3*(y*nx + x) + k]; - } - } - } - } - - return gf; -} - - -struct prompt_encoder_result { - struct ggml_tensor * embd_prompt_sparse = {}; - struct ggml_tensor * embd_prompt_dense = {}; -}; - -// encode a prompt -// -// - points -// - boxes -// - masks -// -// TODO: currently just encode a single point for simplicity -// -prompt_encoder_result sam_encode_prompt( - const sam_model & model, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - sam_state & state) { - - const auto & hparams = model.hparams; - const auto & enc = model.enc_prompt; - - struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, 2); - ggml_set_name(inp, "prompt_input"); - ggml_set_input(inp); - - - struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), inp); - - cur = ggml_scale(ctx0, cur, float(2.0*M_PI)); - - // concat - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192 - { - struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL); - struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL); - - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0))); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1]))); - - // overwrite label == -1 with not_a_point_embed.weight - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L86 - // TODO: extend for multiple points - ggml_build_forward_expand(gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], cur->nb[1]))); - } - - // add point_embeddings[1] to label == 1 - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90 - struct ggml_tensor * v = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_add_inplace(ctx0, v, enc.pt_embd[1]), v)); - - struct ggml_tensor * embd_prompt_sparse = cur; - ggml_build_forward_expand(gf, embd_prompt_sparse); - - struct ggml_tensor * embd_prompt_dense = ggml_repeat(ctx0, - ggml_cont(ctx0, - ggml_view_3d(ctx0, enc.no_mask_embd_w, - 1, 1, enc.no_mask_embd_w->ne[0], enc.no_mask_embd_w->nb[0], enc.no_mask_embd_w->nb[0], 0)), - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hparams.n_img_embd(), hparams.n_img_embd(), hparams.n_enc_out_chans)); - - ggml_build_forward_expand(gf, embd_prompt_dense); - - //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); - - prompt_encoder_result res; - res.embd_prompt_sparse = embd_prompt_sparse; - res.embd_prompt_dense = embd_prompt_dense; - return res; -} - -struct ggml_tensor* sam_decode_mask_transformer_attn( - const sam_layer_dec_transformer_attn & attn, - struct ggml_tensor * queries, - struct ggml_tensor * keys, - struct ggml_tensor * values, - struct ggml_context * ctx0, - const sam_model & model) { - const auto & hparams = model.hparams; - const int n_head = hparams.n_dec_heads; - - struct ggml_tensor * Qcur = {}; - struct ggml_tensor * Kcur = {}; - struct ggml_tensor * Vcur = {}; - - Qcur = ggml_mul_mat(ctx0, attn.q_w, queries); - Qcur = ggml_add_inplace(ctx0, Qcur, attn.q_b); - - Kcur = ggml_mul_mat(ctx0, attn.k_w, keys); - Kcur = ggml_add_inplace(ctx0, Kcur, attn.k_b); - - Vcur = ggml_mul_mat(ctx0, attn.v_w, values); - Vcur = ggml_add_inplace(ctx0, Vcur, attn.v_b); - - struct ggml_tensor * Q = {}; - struct ggml_tensor * K = {}; - struct ggml_tensor * V = {}; - - Q = ggml_reshape_4d(ctx0, Qcur, Qcur->ne[0]/n_head, n_head, Qcur->ne[1], Qcur->ne[2]); - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - - K = ggml_reshape_4d(ctx0, Kcur, Kcur->ne[0]/n_head, n_head, Kcur->ne[1], Kcur->ne[2]); - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - - V = ggml_reshape_4d(ctx0, Vcur, Vcur->ne[0]/n_head, n_head, Vcur->ne[1], Vcur->ne[2]); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); - - // Q * K - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f/sqrt(float(Q->ne[0]))); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, KQ_soft_max, ggml_cont(ctx0, ggml_transpose(ctx0, V))); - - struct ggml_tensor * KQV_merged = ggml_cont(ctx0, ggml_transpose(ctx0, KQV)); - KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3)); - KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]); - KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged); - KQV_merged = ggml_add_inplace(ctx0, KQV_merged, attn.out_b); - - return KQV_merged; -} - -struct ggml_tensor * sam_decode_mask_mlp_relu_3( - struct ggml_tensor * in, - struct ggml_tensor * w_0, - struct ggml_tensor * b_0, - struct ggml_tensor * w_1, - struct ggml_tensor * b_1, - struct ggml_tensor * w_2, - struct ggml_tensor * b_2, - struct ggml_context * ctx0) { - - struct ggml_tensor * cur = {}; - cur = ggml_mul_mat(ctx0, w_0, in); - cur = ggml_add_inplace(ctx0, cur, b_0); - - cur = ggml_relu_inplace(ctx0, cur); - - cur = ggml_mul_mat(ctx0, w_1, cur); - cur = ggml_add_inplace(ctx0, cur, b_1); - - cur = ggml_relu_inplace(ctx0, cur); - - cur = ggml_mul_mat(ctx0, w_2, cur); - cur = ggml_add_inplace(ctx0, cur, b_2); - - return cur; -} - -bool sam_decode_mask( - const sam_model & model, - const prompt_encoder_result & prompt, - struct ggml_tensor * pe_img, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - sam_state & state) { - - const auto & hparams = model.hparams; - const auto & dec = model.dec; - const int n_img_embd = hparams.n_img_embd(); - - struct ggml_tensor * tokens = {}; - { - // Concatenate output tokens - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L120 - const auto& sparse = prompt.embd_prompt_sparse; - - tokens = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, dec.iou_token_w->ne[0], dec.iou_token_w->ne[1] + dec.mask_tokens_w->ne[1] + sparse->ne[1], sparse->ne[2]); - - const size_t offsets[3] = { 0, dec.iou_token_w->ne[1]*tokens->nb[1], dec.iou_token_w->ne[1]*tokens->nb[1] + dec.mask_tokens_w->ne[1]*tokens->nb[1] }; - ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.iou_token_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.iou_token_w->ne[1], tokens->nb[1], offsets[0]))); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1]))); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, sparse, ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1], tokens->nb[1], offsets[2]))); - // TODO: Sparse prompt embeddings can have more than one point - } - - - struct ggml_tensor * src = {}; - struct ggml_tensor * pos_src = {}; - int srcNE[4] = { 0, 0, 0, 0 }; - { - // Expand per-image data in the batch direction to be per-mask - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L125 - src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.embd_img->ne[0], state.embd_img->ne[1], state.embd_img->ne[2], tokens->ne[2]); - - src = ggml_add(ctx0, - ggml_repeat(ctx0, - state.embd_img, - src), - prompt.embd_prompt_dense); - - srcNE[0] = src->ne[0]; - srcNE[1] = src->ne[1]; - srcNE[2] = src->ne[2]; - srcNE[3] = src->ne[3]; - - // flatten & permute - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83 - src = ggml_cont(ctx0, ggml_permute(ctx0, - ggml_view_3d(ctx0, - src, - src->ne[0]*src->ne[1], - src->ne[2], - src->ne[3], - src->nb[2], - src->nb[3], - 0), - 1, 0, 2, 3)); - - pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, pe_img->ne[0], pe_img->ne[1], pe_img->ne[2], tokens->ne[2]); - pos_src = ggml_repeat(ctx0, - pe_img, - pos_src); - - // flatten & permute - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83 - pos_src = ggml_cont(ctx0, ggml_permute(ctx0, - ggml_view_3d(ctx0, - pos_src, - pos_src->ne[0]*pos_src->ne[1], - pos_src->ne[2], - pos_src->ne[3], - pos_src->nb[2], - pos_src->nb[3], - 0), - 1, 0, 2, 3)); - } - - struct ggml_tensor * queries = tokens; - struct ggml_tensor * keys = src; - { - // Run the transformer - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L62 - for (int i = 0; i < int(model.dec.transformer_layers.size()); ++i) { - const auto& tfm_layer = model.dec.transformer_layers[i]; - - // Self attention block - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L154 - const bool skip_first_layer_pe = i == 0; - if (skip_first_layer_pe) { - queries = sam_decode_mask_transformer_attn(tfm_layer.self_attn, queries, queries, queries, ctx0, model); - } - else { - struct ggml_tensor * q_0 = ggml_add(ctx0, queries, tokens); - - struct ggml_tensor * self_attn = sam_decode_mask_transformer_attn(tfm_layer.self_attn, q_0, q_0, queries, ctx0, model); - queries = ggml_add(ctx0, queries, self_attn); - } - - queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add_inplace(ctx0, - ggml_mul(ctx0, queries, tfm_layer.norm1_w), - tfm_layer.norm1_b); - - // Cross attention block, tokens attending to image embedding - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L163 - struct ggml_tensor * q_1 = ggml_add(ctx0, queries, tokens); - struct ggml_tensor * k_1 = ggml_add(ctx0, keys, pos_src); - - struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model); - - queries = ggml_add_inplace(ctx0, queries, cross_attn_token_to_img); - queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add_inplace(ctx0, - ggml_mul(ctx0, queries, tfm_layer.norm2_w), - tfm_layer.norm2_b); - - // MLP block - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L170 - struct ggml_tensor * mlp_out = ggml_mul_mat(ctx0, - tfm_layer.mlp_lin1_w, - queries); - - mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin1_b); - - // RELU activation - mlp_out = ggml_relu_inplace(ctx0, mlp_out); - mlp_out = ggml_mul_mat(ctx0, tfm_layer.mlp_lin2_w, mlp_out); - - mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin2_b); - - queries = ggml_add_inplace(ctx0, queries, mlp_out); - queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add_inplace(ctx0, - ggml_mul(ctx0, queries, tfm_layer.norm3_w), - tfm_layer.norm3_b); - - // Cross attention block, image embedding attending to tokens - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L175 - struct ggml_tensor * q_2 = ggml_add(ctx0, queries, tokens); - struct ggml_tensor * k_2 = ggml_add(ctx0, keys, pos_src); - - struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model); - keys = ggml_add_inplace(ctx0, keys, cross_attn_img_to_token); - keys = ggml_norm_inplace(ctx0, keys, hparams.eps_decoder_transformer); - keys = ggml_add_inplace(ctx0, - ggml_mul(ctx0, keys, tfm_layer.norm4_w), - tfm_layer.norm4_b); - } - - // Apply the final attention layer from the points to the image - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L99 - struct ggml_tensor * q = ggml_add(ctx0, queries, tokens); - struct ggml_tensor * k = ggml_add(ctx0, keys, pos_src); - - struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model); - - queries = ggml_add_inplace(ctx0, queries, final_attn_token_to_img); - queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add_inplace(ctx0, - ggml_mul(ctx0, queries, dec.transformer_norm_final_w), - dec.transformer_norm_final_b); - } - - - struct ggml_tensor * iou_pred = ggml_view_2d(ctx0, queries, queries->ne[0], queries->ne[2], queries->nb[2], 0); - const int num_mask_tokens = 4; // num_multimask_outputs + 1 - struct ggml_tensor * mask_tokens_out = ggml_view_3d(ctx0, queries, queries->ne[0], num_mask_tokens, queries->ne[2], queries->nb[1], num_mask_tokens*queries->nb[1], queries->nb[1]); - - // Upscale mask embeddings and predict masks using the mask tokens - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L136 - keys = ggml_cont(ctx0, ggml_transpose(ctx0, keys)); - keys = ggml_view_4d(ctx0, keys, srcNE[0], srcNE[1], srcNE[2], srcNE[3], srcNE[0]*keys->nb[0], keys->nb[1], keys->nb[2], 0); - // ggml_build_forward_expand(gf, keys); - struct ggml_tensor * upscaled_embedding = {}; - { - // ConvTranspose2d - keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2); - keys = ggml_add_inplace(ctx0, keys, ggml_repeat(ctx0, - ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]), - keys)); - - keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps); - - // GELU activation - keys = ggml_gelu_inplace(ctx0, keys); - - // ConvTranspose2d - keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2); - keys = ggml_add_inplace(ctx0, ggml_repeat(ctx0, - ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]), - keys), keys); - // GELU activation - keys = ggml_gelu_inplace(ctx0, keys); - upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]); - upscaled_embedding = ggml_cont(ctx0, ggml_transpose(ctx0, upscaled_embedding)); // TODO: Shouldn't be needed - } - - struct ggml_tensor * hyper_in = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_img_embd/2, num_mask_tokens, mask_tokens_out->ne[2]); - - for (int i = 0; i < num_mask_tokens; ++i) { - const auto& mlp = dec.output_hypernet_mlps[i]; - struct ggml_tensor * in = ggml_view_2d(ctx0, mask_tokens_out, mask_tokens_out->ne[0], mask_tokens_out->ne[2], mask_tokens_out->nb[1], i*mask_tokens_out->nb[1]); - struct ggml_tensor * out = sam_decode_mask_mlp_relu_3(in, mlp.w_0, mlp.b_0, mlp.w_1, mlp.b_1, mlp.w_2, mlp.b_2, ctx0); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1]))); - } - - struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding); - masks = ggml_cont(ctx0, ggml_transpose(ctx0, masks)); // TODO: Shouldn't be needed - masks = ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]); - - // Generate mask quality predictions - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146 - iou_pred = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0); - - // Select the correct mask or masks for output - // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L101 - iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, iou_pred->ne[0] - 1, iou_pred->nb[0]), state.iou_predictions); - masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3], - masks->nb[1], masks->nb[2], masks->nb[3], masks->nb[2] /* offset*/); - masks = ggml_cpy(state.ctx, masks, state.low_res_masks); - - ggml_build_forward_expand(gf, masks); - ggml_build_forward_expand(gf, iou_pred); - - ggml_disconnect_node_from_graph(state.low_res_masks); - ggml_disconnect_node_from_graph(state.iou_predictions); - - return true; -} - -bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname) { - if (state.low_res_masks->ne[2] == 0) return true; - if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) { - printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]); - return false; - } - - const int n_img_size = hparams.n_img_size(); - const float mask_threshold = hparams.mask_threshold; - const float iou_threshold = hparams.iou_threshold; - const float stability_score_threshold = hparams.stability_score_threshold; - const float intersection_threshold = mask_threshold + hparams.stability_score_offset; - const float union_threshold = mask_threshold - hparams.stability_score_offset; - - const int ne0 = state.low_res_masks->ne[0]; - const int ne1 = state.low_res_masks->ne[1]; - const int ne2 = state.low_res_masks->ne[2]; - - // Remove padding and upscale masks to the original image size. - // ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L140 - - const float preprocess_scale = std::max(nx, ny) / float(n_img_size); - const int cropped_nx = int(nx / preprocess_scale + 0.5f); - const int cropped_ny = int(ny / preprocess_scale + 0.5f); - - const float scale_x_1 = (float)ne0 / (float)n_img_size; - const float scale_y_1 = (float)ne1 / (float)n_img_size; - - const float scale_x_2 = float(cropped_nx) / float(nx); - const float scale_y_2 = float(cropped_ny) / float(ny); - - const auto iou_data = (float*)state.iou_predictions->data; - - for (int i = 0; i < ne2; ++i) { - if (iou_threshold > 0.f && iou_data[i] < iou_threshold) { - printf("Skipping mask %d with iou %f below threshold %f\n", i, iou_data[i], iou_threshold); - continue; // Filtering masks with iou below the threshold - } - - std::vector mask_data(n_img_size*n_img_size); - { - const float* data = (float *) state.low_res_masks->data + i*ne0*ne1; - - for (int iy = 0; iy < n_img_size; ++iy) { - for (int ix = 0; ix < n_img_size; ++ix) { - const float sx = std::max(scale_x_1*(ix + 0.5f) - 0.5f, 0.0f); - const float sy = std::max(scale_y_1*(iy + 0.5f) - 0.5f, 0.0f); - - const int x0 = std::max(0, (int)sx); - const int y0 = std::max(0, (int)sy); - - const int x1 = std::min(x0 + 1, ne0 - 1); - const int y1 = std::min(y0 + 1, ne1 - 1); - - const float dx = sx - x0; - const float dy = sy - y0; - - const int j00 = y0*ne0 + x0; - const int j01 = y0*ne0 + x1; - const int j10 = y1*ne0 + x0; - const int j11 = y1*ne0 + x1; - - const float v00 = data[j00]; - const float v01 = data[j01]; - const float v10 = data[j10]; - const float v11 = data[j11]; - - const float v0 = (1-dx)*v00 + dx*v01; - const float v1 = (1-dx)*v10 + dx*v11; - - const float v = (1-dy)*v0 + dy*v1; - - mask_data[iy*n_img_size + ix] = v; - } - } - } - - int intersections = 0; - int unions = 0; - sam_image_u8 res; - int min_iy = ny; - int max_iy = 0; - int min_ix = nx; - int max_ix = 0; - { - const float* data = mask_data.data(); - - res.nx = nx; - res.ny = ny; - res.data.resize(nx*ny); - - for (int iy = 0; iy < ny; ++iy) { - for (int ix = 0; ix < nx; ++ix) { - const float sx = std::max(scale_x_2*(ix + 0.5f) - 0.5f, 0.0f); - const float sy = std::max(scale_y_2*(iy + 0.5f) - 0.5f, 0.0f); - - const int x0 = std::max(0, (int)sx); - const int y0 = std::max(0, (int)sy); - - const int x1 = std::min(x0 + 1, cropped_nx - 1); - const int y1 = std::min(y0 + 1, cropped_ny - 1); - - const float dx = sx - x0; - const float dy = sy - y0; - - const int j00 = y0*n_img_size + x0; - const int j01 = y0*n_img_size + x1; - const int j10 = y1*n_img_size + x0; - const int j11 = y1*n_img_size + x1; - - const float v00 = data[j00]; - const float v01 = data[j01]; - const float v10 = data[j10]; - const float v11 = data[j11]; - - const float v0 = (1-dx)*v00 + dx*v01; - const float v1 = (1-dx)*v10 + dx*v11; - - const float v = (1-dy)*v0 + dy*v1; - - if (v > intersection_threshold) { - intersections++; - } - if (v > union_threshold) { - unions++; - } - if (v > mask_threshold) { - min_iy = std::min(min_iy, iy); - max_iy = std::max(max_iy, iy); - min_ix = std::min(min_ix, ix); - max_ix = std::max(max_ix, ix); - - res.data[iy*nx + ix] = 255; - } - } - } - } - - const float stability_score = float(intersections) / float(unions); - if (stability_score_threshold > 0.f && stability_score < stability_score_threshold) { - printf("Skipping mask %d with stability score %f below threshold %f\n", i, stability_score, stability_score_threshold); - continue; // Filtering masks with stability score below the threshold - } - - printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n", - i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy); - - std::string filename = fname + std::to_string(i) + ".png"; - if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) { - printf("%s: failed to write mask %s\n", __func__, filename.c_str()); - return false; - } - } - - - return true; -} - -struct ggml_cgraph * sam_build_fast_graph( - const sam_model & model, - sam_state & state, - int nx, - int ny, - sam_point point) { - - struct ggml_init_params ggml_params = { - /*.mem_size =*/ state.buf_compute_fast.size(), - /*.mem_buffer =*/ state.buf_compute_fast.data(), - /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements - }; - - struct ggml_context * ctx0 = ggml_init(ggml_params); - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state); - if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) { - fprintf(stderr, "%s: failed to encode prompt (%f, %f)\n", __func__, point.x, point.y); - return {}; - } - - struct ggml_tensor * pe_img_dense = sam_fill_dense_pe(model, ctx0, gf, state); - if (!pe_img_dense) { - fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__); - return {}; - } - - if (!sam_decode_mask(model, enc_res, pe_img_dense, ctx0, gf, state)) { - fprintf(stderr, "%s: failed to decode mask\n", __func__); - return {}; - } - - ggml_free(ctx0); - - ggml_gallocr_alloc_graph(state.allocr, gf); - - // from sam_encode_prompt - { - // transform points - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py#L276 - { - const int nmax = std::max(nx, ny); - - const float scale = model.hparams.n_img_size() / (float) nmax; - - const int nx_new = int(nx*scale + 0.5f); - const int ny_new = int(ny*scale + 0.5f); - - point.x = point.x*(float(nx_new)/nx) + 0.5f; - point.y = point.y*(float(ny_new)/ny) + 0.5f; - } - - struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "prompt_input"); - // set the input by converting the [0, 1] coordinates to [-1, 1] - float * data = (float *) inp->data; - - data[0] = 2.0f*(point.x / model.hparams.n_img_size()) - 1.0f; - data[1] = 2.0f*(point.y / model.hparams.n_img_size()) - 1.0f; - - // padding - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L81-L85 - data[2] = 2.0f*(0.0f) - 1.0f; - data[3] = 2.0f*(0.0f) - 1.0f; - } - - // from sam_fill_dense_pe - { - struct ggml_tensor * xy_embed_stacked = ggml_graph_get_tensor(gf, "xy_embed_stacked"); - const int32_t n_img_embd = model.hparams.n_img_embd(); - const float n_img_embd_inv = 1.0f / n_img_embd; - float * data = (float *) ggml_get_data(xy_embed_stacked); - for (int i = 0; i < n_img_embd; ++i) { - const int row = 2*i*n_img_embd; - const float y_val = 2 * (i + 0.5f) * n_img_embd_inv - 1; - for (int j = 0; j < n_img_embd; ++j) { - const float x_val = 2 * (j + 0.5f) * n_img_embd_inv - 1; - data[row + 2*j + 0] = x_val; - data[row + 2*j + 1] = y_val; - } - } - } - - return gf; -} - -void sam_print_usage(int argc, char ** argv, const sam_params & params) { - fprintf(stderr, "usage: %s [options]\n", argv[0]); - fprintf(stderr, "\n"); - fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); - fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); - fprintf(stderr, " -m FNAME, --model FNAME\n"); - fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); - fprintf(stderr, " -i FNAME, --inp FNAME\n"); - fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str()); - fprintf(stderr, " -o FNAME, --out FNAME\n"); - fprintf(stderr, " mask file name prefix (default: %s)\n", params.fname_out.c_str()); - fprintf(stderr, "SAM hyperparameters:\n"); - fprintf(stderr, " -mt FLOAT, --mask-threshold\n"); - fprintf(stderr, " mask threshold (default: %f)\n", params.mask_threshold); - fprintf(stderr, " -it FLOAT, --iou-threshold\n"); - fprintf(stderr, " iou threshold (default: %f)\n", params.iou_threshold); - fprintf(stderr, " -st FLOAT, --score-threshold\n"); - fprintf(stderr, " score threshold (default: %f)\n", params.stability_score_threshold); - fprintf(stderr, " -so FLOAT, --score-offset\n"); - fprintf(stderr, " score offset (default: %f)\n", params.stability_score_offset); - fprintf(stderr, " -e FLOAT, --epsilon\n"); - fprintf(stderr, " epsilon (default: %f)\n", params.eps); - fprintf(stderr, " -ed FLOAT, --epsilon-decoder-transformer\n"); - fprintf(stderr, " epsilon decoder transformer (default: %f)\n", params.eps_decoder_transformer); - fprintf(stderr, "SAM prompt:\n"); - fprintf(stderr, " -p TUPLE, --point-prompt\n"); - fprintf(stderr, " point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \n", params.pt.x, params.pt.y); - fprintf(stderr, "\n"); -} - -bool sam_params_parse(int argc, char ** argv, sam_params & params) { - for (int i = 1; i < argc; i++) { - std::string arg = argv[i]; - - if (arg == "-s" || arg == "--seed") { - params.seed = std::stoi(argv[++i]); - } else if (arg == "-t" || arg == "--threads") { - params.n_threads = std::stoi(argv[++i]); - } else if (arg == "-m" || arg == "--model") { - params.model = argv[++i]; - } else if (arg == "-i" || arg == "--inp") { - params.fname_inp = argv[++i]; - } else if (arg == "-o" || arg == "--out") { - params.fname_out = argv[++i]; - } else if (arg == "-mt" || arg == "--mask-threshold") { - params.mask_threshold = std::stof(argv[++i]); - } else if (arg == "-it" || arg == "--iou-threshold") { - params.iou_threshold = std::stof(argv[++i]); - } else if (arg == "-st" || arg == "--score-threshold") { - params.stability_score_threshold = std::stof(argv[++i]); - } else if (arg == "-so" || arg == "--score-offset") { - params.stability_score_offset = std::stof(argv[++i]); - } else if (arg == "-e" || arg == "--epsilon") { - params.eps = std::stof(argv[++i]); - } else if (arg == "-ed" || arg == "--epsilon-decoder-transformer") { - params.eps_decoder_transformer = std::stof(argv[++i]); - } else if (arg == "-p" || arg == "--point-prompt") { - // TODO multiple points per model invocation - char* point = argv[++i]; - - char* coord = strtok(point, ","); - if (!coord){ - fprintf(stderr, "Error while parsing prompt!\n"); - exit(1); - } - params.pt.x = std::stof(coord); - coord = strtok(NULL, ","); - if (!coord){ - fprintf(stderr, "Error while parsing prompt!\n"); - exit(1); - } - params.pt.y = std::stof(coord); - } else if (arg == "-h" || arg == "--help") { - sam_print_usage(argc, argv, params); - exit(0); - } else { - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - sam_print_usage(argc, argv, params); - exit(0); - } - } - - return true; -} - -int main(int argc, char ** argv) { - const int64_t t_main_start_us = ggml_time_us(); - - sam_params params; - params.model = "models/sam-vit-b/ggml-model-f16.bin"; - - sam_model model; - sam_state state; - int64_t t_load_us = 0; - - if (sam_params_parse(argc, argv, params) == false) { - return 1; - } - - if (params.seed < 0) { - params.seed = time(NULL); - } - fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); - - // load the image - sam_image_u8 img0; - if (!sam_image_load_from_file(params.fname_inp, img0)) { - fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str()); - return 1; - } - fprintf(stderr, "%s: loaded image '%s' (%d x %d)\n", __func__, params.fname_inp.c_str(), img0.nx, img0.ny); - - // preprocess to f32 - sam_image_f32 img1; - if (!sam_image_preprocess(img0, img1)) { - fprintf(stderr, "%s: failed to preprocess image\n", __func__); - return 1; - } - fprintf(stderr, "%s: preprocessed image (%d x %d)\n", __func__, img1.nx, img1.ny); - - - // load the model - { - const int64_t t_start_us = ggml_time_us(); - - if (!sam_model_load(params, model)) { - fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); - return 1; - } - - t_load_us = ggml_time_us() - t_start_us; - } - - { - static size_t buf_size = 256u*1024*1024; - - struct ggml_init_params ggml_params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ false, - }; - - state.ctx = ggml_init(ggml_params); - - state.embd_img = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, - model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans); - - state.low_res_masks = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, - model.hparams.n_enc_out_chans, model.hparams.n_enc_out_chans, 3); - - state.iou_predictions = ggml_new_tensor_1d(state.ctx, GGML_TYPE_F32, 3); - } - - - { - state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead()); - state.allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type()); - - struct ggml_cgraph * gf = sam_encode_image(model, state, img1); - if (!gf) { - fprintf(stderr, "%s: failed to encode image\n", __func__); - return 1; - } - - ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads); - - print_t_f32("embd_img", state.embd_img); - - ggml_gallocr_free(state.allocr); - state.allocr = NULL; - state.work_buffer.clear(); - } - { - state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead()); - state.allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type()); - - // TODO: more varied prompts - fprintf(stderr, "prompt: (%f, %f)\n", params.pt.x, params.pt.y); - - struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.pt); - if (!gf) { - fprintf(stderr, "%s: failed to build fast graph\n", __func__); - return 1; - } - - ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads); - - //print_t_f32("iou_predictions", state.iou_predictions); - //print_t_f32("low_res_masks", state.low_res_masks); - ggml_gallocr_free(state.allocr); - state.allocr = NULL; - } - - if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out)) { - fprintf(stderr, "%s: failed to write masks\n", __func__); - return 1; - } - - // report timing - { - const int64_t t_main_end_us = ggml_time_us(); - - fprintf(stderr, "\n\n"); - fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); - fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); - } - - ggml_free(model.ctx); - - return 0; -} diff --git a/examples/sam/sam.cpp b/examples/sam/sam.cpp new file mode 100644 index 00000000..299318f5 --- /dev/null +++ b/examples/sam/sam.cpp @@ -0,0 +1,2365 @@ +#define _USE_MATH_DEFINES // for M_PI +#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows + +#include "ggml.h" +#include "ggml-cpu.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" +#define STB_IMAGE_WRITE_IMPLEMENTATION +#include "stb_image_write.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +// default hparams (ViT-B SAM) +struct sam_hparams { + int32_t n_enc_state = 768; + int32_t n_enc_layer = 12; + int32_t n_enc_head = 12; + int32_t n_enc_out_chans = 256; + int32_t n_pt_embd = 4; + int32_t n_dec_heads = 8; + int32_t ftype = 1; + float mask_threshold = 0.f; + float iou_threshold = 0.88f; + float stability_score_threshold = 0.95f; + float stability_score_offset = 1.0f; + float eps = 1e-6f; + float eps_decoder_transformer = 1e-5f; + + int32_t n_enc_head_dim() const { return n_enc_state / n_enc_head; } + int32_t n_img_size() const { return 1024; } + int32_t n_window_size() const { return 14; } + int32_t n_patch_size() const { return 16; } + int32_t n_img_embd() const { return n_img_size() / n_patch_size(); } + + std::vector global_attn_indices() const { + switch (n_enc_state) { + case 768: return { 2, 5, 8, 11 }; + case 1024: return { 5, 11, 17, 23 }; + case 1280: return { 7, 15, 23, 31 }; + default: + { + fprintf(stderr, "%s: unsupported n_enc_state = %d\n", __func__, n_enc_state); + } break; + }; + + return {}; + } + + bool is_global_attn(int32_t layer) const { + const auto indices = global_attn_indices(); + + for (const auto & idx : indices) { + if (layer == idx) { + return true; + } + } + + return false; + } +}; + +struct sam_layer_enc { + struct ggml_tensor * norm1_w; + struct ggml_tensor * norm1_b; + + struct ggml_tensor * rel_pos_w; + struct ggml_tensor * rel_pos_h; + + struct ggml_tensor * qkv_w; + struct ggml_tensor * qkv_b; + + struct ggml_tensor * proj_w; + struct ggml_tensor * proj_b; + + struct ggml_tensor * norm2_w; + struct ggml_tensor * norm2_b; + + struct ggml_tensor * mlp_lin1_w; + struct ggml_tensor * mlp_lin1_b; + + struct ggml_tensor * mlp_lin2_w; + struct ggml_tensor * mlp_lin2_b; +}; + +struct sam_encoder_image { + struct ggml_tensor * pe; + + struct ggml_tensor * proj_w; + struct ggml_tensor * proj_b; + + struct ggml_tensor * neck_conv_0; + struct ggml_tensor * neck_norm_0_w; + struct ggml_tensor * neck_norm_0_b; + struct ggml_tensor * neck_conv_1; + struct ggml_tensor * neck_norm_1_w; + struct ggml_tensor * neck_norm_1_b; + + std::vector layers; +}; + +struct sam_encoder_prompt { + struct ggml_tensor * pe; + + struct ggml_tensor * not_a_pt_embd_w; + std::vector pt_embd; + + struct ggml_tensor * no_mask_embd_w; + //std::vector mask_down_w; + //std::vector mask_down_b; +}; + +struct sam_layer_dec_transformer_attn { + // q_proj + struct ggml_tensor * q_w; + struct ggml_tensor * q_b; + + // k_proj + struct ggml_tensor * k_w; + struct ggml_tensor * k_b; + + // v_proj + struct ggml_tensor * v_w; + struct ggml_tensor * v_b; + + // out_proj + struct ggml_tensor * out_w; + struct ggml_tensor * out_b; +}; + +struct sam_layer_dec_transformer { + sam_layer_dec_transformer_attn self_attn; + + // norm1 + struct ggml_tensor * norm1_w; + struct ggml_tensor * norm1_b; + + sam_layer_dec_transformer_attn cross_attn_token_to_img; + + // norm2 + struct ggml_tensor * norm2_w; + struct ggml_tensor * norm2_b; + + // mlp.lin1 + struct ggml_tensor * mlp_lin1_w; + struct ggml_tensor * mlp_lin1_b; + + // mlp.lin2 + struct ggml_tensor * mlp_lin2_w; + struct ggml_tensor * mlp_lin2_b; + + // norm3 + struct ggml_tensor * norm3_w; + struct ggml_tensor * norm3_b; + + // norm4 + struct ggml_tensor * norm4_w; + struct ggml_tensor * norm4_b; + + sam_layer_dec_transformer_attn cross_attn_img_to_token; +}; + +struct sam_layer_dec_output_hypernet_mlps { + // mlps_*.layers.0 + struct ggml_tensor * w_0; + struct ggml_tensor * b_0; + + // mlps_*.layers.1 + struct ggml_tensor * w_1; + struct ggml_tensor * b_1; + + // mlps_*.layers.2 + struct ggml_tensor * w_2; + struct ggml_tensor * b_2; +}; + +struct sam_decoder_mask { + std::vector transformer_layers; + + // trasnformer.final_attn_token_to_image + sam_layer_dec_transformer_attn transformer_final_attn_token_to_img; + + // transformer.norm_final + struct ggml_tensor * transformer_norm_final_w; + struct ggml_tensor * transformer_norm_final_b; + + // output_upscaling.0 + struct ggml_tensor * output_upscaling_0_w; + struct ggml_tensor * output_upscaling_0_b; + + // output_upscaling.1 + struct ggml_tensor * output_upscaling_1_w; + struct ggml_tensor * output_upscaling_1_b; + + // output_upscaling.3 + struct ggml_tensor * output_upscaling_3_w; + struct ggml_tensor * output_upscaling_3_b; + + // output_hypernetworks_mlps + std::vector output_hypernet_mlps; + + // iou_prediction_head.0 + struct ggml_tensor * iou_prediction_head_0_w; + struct ggml_tensor * iou_prediction_head_0_b; + + // iou_prediction_head.1 + struct ggml_tensor * iou_prediction_head_1_w; + struct ggml_tensor * iou_prediction_head_1_b; + + // iou_prediction_head.2 + struct ggml_tensor * iou_prediction_head_2_w; + struct ggml_tensor * iou_prediction_head_2_b; + + // iou_token.weight + struct ggml_tensor * iou_token_w; + + // mask_tokens.weight + struct ggml_tensor * mask_tokens_w; +}; + + +struct sam_state { + struct ggml_tensor * embd_img; + + struct ggml_tensor * low_res_masks; + struct ggml_tensor * iou_predictions; + + //struct ggml_tensor * tmp_save = {}; + + struct ggml_context * ctx; + + // buffer for `ggml_graph_plan.work_data` + std::vector work_buffer; + // buffers to evaluate the model + std::vector buf_compute_img_enc; + + std::vector buf_compute_fast; + + ggml_gallocr_t allocr = {}; +}; + +// void save_tensor(sam_state& state, struct ggml_tensor * t, struct ggml_cgraph * gf) { +// if (!state.tmp_save) { +// state.tmp_save = ggml_new_tensor(state.ctx, t->type, t->n_dims, t->ne); +// } +// struct ggml_tensor * tmp0 = ggml_cpy(state.ctx, t, state.tmp_save); +// ggml_build_forward_expand(gf, tmp0); +// } + +struct sam_model { + sam_hparams hparams; + + sam_encoder_image enc_img; + sam_encoder_prompt enc_prompt; + sam_decoder_mask dec; + + // + struct ggml_context * ctx; + std::map tensors; +}; + +struct sam_point { + float x; + float y; +}; + +struct sam_box { + float x1; + float y1; + float x2; + float y2; +}; + +// RGB uint8 image +struct sam_image_u8 { + int nx; + int ny; + + std::vector data; +}; + +// RGB float32 image +// Memory layout: RGBRGBRGB... +struct sam_image_f32 { + int nx; + int ny; + + std::vector data; +}; + +enum sam_prompt_type { + SAM_PROMPT_TYPE_POINT = 0, + SAM_PROMPT_TYPE_BOX = 1, +}; + +struct sam_prompt { + sam_prompt_type prompt_type = SAM_PROMPT_TYPE_POINT; + sam_point pt = { 414.375f, 162.796875f, }; + sam_box box = { 368.0f, 144.0f, 441.0f, 173.0f }; +}; + +struct sam_params { + int32_t seed = -1; // RNG seed + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path + std::string fname_inp = "img.jpg"; + std::string fname_out = "img.out"; + float mask_threshold = 0.f; + float iou_threshold = 0.88f; + float stability_score_threshold = 0.95f; + float stability_score_offset = 1.0f; + float eps = 1e-6f; + float eps_decoder_transformer = 1e-5f; + + sam_prompt prompt; + bool multimask_output = true; +}; + +void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) { + printf("%s\n", title); + float * data = (float *)t->data; + printf("dims: % " PRId64 " % " PRId64 " % " PRId64 " % " PRId64 " f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]); + printf("First & Last %d elements:\n", n); + for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) { + printf("%.5f ", data[i]); + if (i != 0 && i % t->ne[0] == 0) { + printf("\n"); + } + } + printf("\n"); + for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) { + printf("%.5f ", data[ggml_nelements(t) - n + i]); + if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) { + printf("\n"); + } + } + printf("\n"); + double sum = 0.0; + for (int i = 0; i < ggml_nelements(t); i++) { + sum += data[i]; + } + printf("sum: %f\n\n", sum); +} + +static void ggml_disconnect_node_from_graph(ggml_tensor * t) { + t->op = GGML_OP_NONE; + for (int i = 0; i < GGML_MAX_SRC; i++) { + t->src[i] = NULL; + } +} + +static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr); + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + ggml_graph_compute(graph, &plan); +} + +static void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) { + GGML_ASSERT(userdata == NULL); + GGML_ASSERT(ggml_are_same_shape(dst, src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src)); + + const float * src_data = ggml_get_data_f32(src); + float * dst_data = ggml_get_data_f32(dst); + + const int ne = (int)ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = std::min(ie0 + dr, ne); + + for (int i = ie0; i < ie1; ++i) { + dst_data[i] = sinf(src_data[i]); + } +} + +static void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) { + GGML_ASSERT(userdata == NULL); + GGML_ASSERT(ggml_are_same_shape(dst, src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src)); + + const float * src_data = ggml_get_data_f32(src); + float * dst_data = ggml_get_data_f32(dst); + + const int ne = (int)ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = std::min(ie0 + dr, ne); + + for (int i = ie0; i < ie1; ++i) { + dst_data[i] = cosf(src_data[i]); + } +} + +bool sam_image_load_from_file(const std::string & fname, sam_image_u8 & img) { + int nx, ny, nc; + auto data = stbi_load(fname.c_str(), &nx, &ny, &nc, 3); + if (!data) { + fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str()); + return false; + } + + img.nx = nx; + img.ny = ny; + img.data.resize(nx * ny * 3); + memcpy(img.data.data(), data, nx * ny * 3); + + stbi_image_free(data); + + return true; +} + +// ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L164 +// resize largest dimension to 1024 +// normalize: x = (x - mean) / std +// mean = [123.675, 116.28, 103.53] +// std = [58.395, 57.12, 57.375] +// TODO: why are these hardcoded !? +// pad to 1024x1024 +// TODO: for some reason, this is not numerically identical to pytorch's interpolation +bool sam_image_preprocess(const sam_image_u8 & img, sam_image_f32 & res) { + const int nx = img.nx; + const int ny = img.ny; + + const int nx2 = 1024; + const int ny2 = 1024; + + res.nx = nx2; + res.ny = ny2; + res.data.resize(3*nx2*ny2); + + const float scale = std::max(nx, ny) / 1024.0f; + + fprintf(stderr, "%s: scale = %f\n", __func__, scale); + + const int nx3 = int(nx/scale + 0.5f); + const int ny3 = int(ny/scale + 0.5f); + + const float m3[3] = { 123.675f, 116.280f, 103.530f }; + const float s3[3] = { 58.395f, 57.120f, 57.375f }; + + for (int y = 0; y < ny3; y++) { + for (int x = 0; x < nx3; x++) { + for (int c = 0; c < 3; c++) { + // linear interpolation + const float sx = (x + 0.5f)*scale - 0.5f; + const float sy = (y + 0.5f)*scale - 0.5f; + + const int x0 = std::max(0, (int) std::floor(sx)); + const int y0 = std::max(0, (int) std::floor(sy)); + + const int x1 = std::min(x0 + 1, nx - 1); + const int y1 = std::min(y0 + 1, ny - 1); + + const float dx = sx - x0; + const float dy = sy - y0; + + const int j00 = 3*(y0*nx + x0) + c; + const int j01 = 3*(y0*nx + x1) + c; + const int j10 = 3*(y1*nx + x0) + c; + const int j11 = 3*(y1*nx + x1) + c; + + const float v00 = img.data[j00]; + const float v01 = img.data[j01]; + const float v10 = img.data[j10]; + const float v11 = img.data[j11]; + + const float v0 = v00*(1.0f - dx) + v01*dx; + const float v1 = v10*(1.0f - dx) + v11*dx; + + const float v = v0*(1.0f - dy) + v1*dy; + + const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); + + const int i = 3*(y*nx3 + x) + c; + + res.data[i] = (float(v2) - m3[c]) / s3[c]; + } + } + } + + return true; +} + +// load the model's weights from a file +bool sam_model_load(const sam_params & params, sam_model & model) { + fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, params.model.c_str()); + + auto fin = std::ifstream(params.model, std::ios::binary); + if (!fin) { + fprintf(stderr, "%s: failed to open '%s'\n", __func__, params.model.c_str()); + return false; + } + + // verify magic + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + if (magic != 0x67676d6c) { + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, params.model.c_str()); + return false; + } + } + + // load hparams + { + // Override defaults with user choices + model.hparams.mask_threshold = params.mask_threshold; + model.hparams.iou_threshold = params.iou_threshold; + model.hparams.stability_score_threshold = params.stability_score_threshold; + model.hparams.stability_score_offset = params.stability_score_offset; + model.hparams.eps = params.eps; + model.hparams.eps_decoder_transformer = params.eps_decoder_transformer; + + auto & hparams = model.hparams; + + fin.read((char *) &hparams.n_enc_state, sizeof(hparams.n_enc_state)); + fin.read((char *) &hparams.n_enc_layer, sizeof(hparams.n_enc_layer)); + fin.read((char *) &hparams.n_enc_head, sizeof(hparams.n_enc_head)); + fin.read((char *) &hparams.n_enc_out_chans, sizeof(hparams.n_enc_out_chans)); + fin.read((char *) &hparams.n_pt_embd, sizeof(hparams.n_pt_embd)); + fin.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + printf("%s: n_enc_state = %d\n", __func__, hparams.n_enc_state); + printf("%s: n_enc_layer = %d\n", __func__, hparams.n_enc_layer); + printf("%s: n_enc_head = %d\n", __func__, hparams.n_enc_head); + printf("%s: n_enc_out_chans = %d\n", __func__, hparams.n_enc_out_chans); + printf("%s: n_pt_embd = %d\n", __func__, hparams.n_pt_embd); + printf("%s: ftype = %d\n", __func__, hparams.ftype); + printf("%s: qntvr = %d\n", __func__, qntvr); + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + } + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); + if (wtype == GGML_TYPE_COUNT) { + fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n", + __func__, params.model.c_str(), model.hparams.ftype); + return false; + } + + auto & ctx = model.ctx; + + const size_t ctx_size = [&]() { + size_t ctx_size = 0; + + const auto & hparams = model.hparams; + + const int32_t n_enc_state = hparams.n_enc_state; + const int32_t n_enc_layer = hparams.n_enc_layer; + const int32_t n_enc_head_dim = hparams.n_enc_head_dim(); + const int32_t n_enc_out_chans = hparams.n_enc_out_chans; + const int32_t n_pt_embd = hparams.n_pt_embd; + + const int32_t n_enc_layer_local = hparams.global_attn_indices().size(); + const int32_t n_enc_layer_global = n_enc_layer - n_enc_layer_local; + + const int32_t n_img_embd = hparams.n_img_embd(); + const int32_t n_window_size = hparams.n_window_size(); + const int32_t n_patch_size = hparams.n_patch_size(); + + // image encoder + { + ctx_size += n_enc_state*n_img_embd*n_img_embd*ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_state*3*n_patch_size*n_patch_size*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_state*ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_state*n_enc_out_chans*1*1*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_out_chans*n_enc_out_chans*3*3*ggml_type_size(GGML_TYPE_F16); + + ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + } + + // image encoder layers + { + ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); + ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_layer_global*n_enc_head_dim*(2*n_img_embd - 1)*ggml_type_size(GGML_TYPE_F16); + + ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_layer_local*n_enc_head_dim*(2*n_window_size - 1)*ggml_type_size(GGML_TYPE_F16); + + ctx_size += n_enc_layer*3*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_layer*3*n_enc_state* ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_layer*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_layer*n_enc_state* ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); + ctx_size += n_enc_layer*n_enc_state*ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_layer*4*n_enc_state* ggml_type_size(GGML_TYPE_F32); + + ctx_size += n_enc_layer*4*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_enc_layer*4*n_enc_state* ggml_type_size(GGML_TYPE_F32); + } + + ctx_size += (8 + 14*n_enc_layer)*ggml_tensor_overhead(); + + // prompt encoder + { + ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); // 2*(n_enc_out_chans/2) + + ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + } + + ctx_size += (2 + n_pt_embd)*ggml_tensor_overhead(); + + // mask decoder + { + //transformer + { + const int tfm_layers_count = 2; + const int qkv_count = 3; + const int norm_count = 4; + const int n_hypernet_mpls_count = 4; + + // self_attn + ctx_size += tfm_layers_count*qkv_count*n_enc_state*n_enc_state*ggml_type_size(GGML_TYPE_F16); + ctx_size += tfm_layers_count*qkv_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); + ctx_size += tfm_layers_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); + + // all norms + ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); + ctx_size += tfm_layers_count*norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); + + // cross_attn_token_to_img + ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16); + ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)* ggml_type_size(GGML_TYPE_F32); + ctx_size += tfm_layers_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); + + // mlp + ctx_size += tfm_layers_count*8*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); + ctx_size += tfm_layers_count*8*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); + ctx_size += tfm_layers_count*n_enc_out_chans*8*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); + ctx_size += tfm_layers_count*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); + + // cross_attn_img_to_token + ctx_size += tfm_layers_count*qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16); + ctx_size += tfm_layers_count*qkv_count*(n_enc_state/2)* ggml_type_size(GGML_TYPE_F32); + ctx_size += tfm_layers_count*n_enc_state* ggml_type_size(GGML_TYPE_F32); + + // transformer_final_attn_token_to_img + ctx_size += qkv_count*n_enc_state*(n_enc_state/2)*ggml_type_size(GGML_TYPE_F16); + ctx_size += qkv_count*(n_enc_state/2)* ggml_type_size(GGML_TYPE_F32); + ctx_size += n_enc_state* ggml_type_size(GGML_TYPE_F32); + + // transformer_norm_final + ctx_size += norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); + ctx_size += norm_count*n_enc_state*ggml_type_size(GGML_TYPE_F32); + + // output_upscaling + ctx_size += n_enc_out_chans*n_img_embd*2*2*ggml_type_size(GGML_TYPE_F16); + ctx_size += 3*n_img_embd* ggml_type_size(GGML_TYPE_F32); + ctx_size += n_enc_out_chans*n_img_embd*(n_img_embd/2)*2*2*ggml_type_size(GGML_TYPE_F16); + ctx_size += (n_img_embd/2)* ggml_type_size(GGML_TYPE_F32); + + // output_hypernetworks_mlps + ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_hypernet_mpls_count*2*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); + ctx_size += n_hypernet_mpls_count*n_enc_out_chans*(n_img_embd/2)*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_hypernet_mpls_count*(n_img_embd/2)* ggml_type_size(GGML_TYPE_F32); + + // iou_prediction_head + ctx_size += 2*n_enc_out_chans*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); + ctx_size += 2*n_enc_out_chans* ggml_type_size(GGML_TYPE_F32); + ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F16); + ctx_size += n_pt_embd* ggml_type_size(GGML_TYPE_F32); + + // iou_token_w + ctx_size += n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + + // mask_tokens_w + ctx_size += n_pt_embd*n_enc_out_chans*ggml_type_size(GGML_TYPE_F32); + } + } + fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + + return ctx_size; + }(); + + // create the ggml context + { + struct ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + ctx = ggml_init(params); + if (!ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare memory for the weights + { + const auto & hparams = model.hparams; + + const int32_t n_enc_state = hparams.n_enc_state; + const int32_t n_enc_layer = hparams.n_enc_layer; + const int32_t n_enc_head_dim = hparams.n_enc_head_dim(); + const int32_t n_enc_out_chans = hparams.n_enc_out_chans; + const int32_t n_pt_embd = hparams.n_pt_embd; + + const int32_t n_img_embd = hparams.n_img_embd(); + const int32_t n_window_size = hparams.n_window_size(); + const int32_t n_patch_size = hparams.n_patch_size(); + + model.enc_img.layers.resize(n_enc_layer); + + // image encoder + { + auto & enc = model.enc_img; + + enc.pe = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_enc_state, n_img_embd, n_img_embd, 1); + + enc.proj_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_patch_size, n_patch_size, 3, n_enc_state); + enc.proj_b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, 1, n_enc_state); + + enc.neck_conv_0 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, n_enc_state, n_enc_out_chans); + enc.neck_conv_1 = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, n_enc_out_chans, n_enc_out_chans); + + enc.neck_norm_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + enc.neck_norm_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + enc.neck_norm_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + enc.neck_norm_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + model.tensors["image_encoder.pos_embed"] = enc.pe; + + model.tensors["image_encoder.patch_embed.proj.weight"] = enc.proj_w; + model.tensors["image_encoder.patch_embed.proj.bias"] = enc.proj_b; + + model.tensors["image_encoder.neck.0.weight"] = enc.neck_conv_0; + model.tensors["image_encoder.neck.2.weight"] = enc.neck_conv_1; + + model.tensors["image_encoder.neck.1.weight"] = enc.neck_norm_0_w; + model.tensors["image_encoder.neck.1.bias"] = enc.neck_norm_0_b; + + model.tensors["image_encoder.neck.3.weight"] = enc.neck_norm_1_w; + model.tensors["image_encoder.neck.3.bias"] = enc.neck_norm_1_b; + + for (int i = 0; i < n_enc_layer; ++i) { + auto & layer = enc.layers[i]; + + layer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); + layer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); + + if (hparams.is_global_attn(i)) { + layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1); + layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_img_embd - 1); + } else { + layer.rel_pos_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1); + layer.rel_pos_h = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_head_dim, 2*n_window_size - 1); + } + + layer.qkv_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, 3*n_enc_state); + layer.qkv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_enc_state); + + layer.proj_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, n_enc_state); + layer.proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); + + layer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); + layer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); + + layer.mlp_lin1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_state, 4*n_enc_state); + layer.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_enc_state); + + layer.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4*n_enc_state, n_enc_state); + layer.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_state); + + model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm1.weight"] = layer.norm1_w; + model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm1.bias"] = layer.norm1_b; + + model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.rel_pos_w"] = layer.rel_pos_w; + model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.rel_pos_h"] = layer.rel_pos_h; + + model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.qkv.weight"] = layer.qkv_w; + model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.qkv.bias"] = layer.qkv_b; + + model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.proj.weight"] = layer.proj_w; + model.tensors["image_encoder.blocks." + std::to_string(i) + ".attn.proj.bias"] = layer.proj_b; + + model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm2.weight"] = layer.norm2_w; + model.tensors["image_encoder.blocks." + std::to_string(i) + ".norm2.bias"] = layer.norm2_b; + + model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin1.weight"] = layer.mlp_lin1_w; + model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin1.bias"] = layer.mlp_lin1_b; + + model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin2.weight"] = layer.mlp_lin2_w; + model.tensors["image_encoder.blocks." + std::to_string(i) + ".mlp.lin2.bias"] = layer.mlp_lin2_b; + } + } + + // prompt encoder + { + auto & enc = model.enc_prompt; + + enc.pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans/2, 2); + + enc.not_a_pt_embd_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + enc.no_mask_embd_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + model.tensors["prompt_encoder.pe_layer.positional_encoding_gaussian_matrix"] = enc.pe; + model.tensors["prompt_encoder.not_a_point_embed.weight"] = enc.not_a_pt_embd_w; + model.tensors["prompt_encoder.no_mask_embed.weight"] = enc.no_mask_embd_w; + + enc.pt_embd.resize(n_pt_embd); + for (int i = 0; i < n_pt_embd; i++) { + enc.pt_embd[i] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + model.tensors["prompt_encoder.point_embeddings." + std::to_string(i) + ".weight"] = enc.pt_embd[i]; + } + } + + // mask decoder + { + auto & dec = model.dec; + auto & tfm_layers = dec.transformer_layers; + + const int tfm_layers_count = 2; + tfm_layers.resize(tfm_layers_count); + for (int i = 0; i < tfm_layers_count; ++i) { + auto& l = tfm_layers[i]; + l.self_attn.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + l.self_attn.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + l.self_attn.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + l.self_attn.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + l.self_attn.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + l.self_attn.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + l.self_attn.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + l.self_attn.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + l.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + l.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + l.cross_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + l.cross_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + l.cross_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + l.cross_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + l.cross_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + l.cross_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + l.cross_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans); + l.cross_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + l.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + l.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + l.mlp_lin1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, 8*n_enc_out_chans); + l.mlp_lin1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8*n_enc_out_chans); + l.mlp_lin2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 8*n_enc_out_chans, n_enc_out_chans); + l.mlp_lin2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + l.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + l.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + l.norm4_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + l.norm4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + l.cross_attn_img_to_token.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + l.cross_attn_img_to_token.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + l.cross_attn_img_to_token.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + l.cross_attn_img_to_token.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + l.cross_attn_img_to_token.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + l.cross_attn_img_to_token.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + l.cross_attn_img_to_token.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans); + l.cross_attn_img_to_token.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + const auto prefix = "mask_decoder.transformer.layers." + std::to_string(i) + "."; + model.tensors[prefix + "self_attn.q_proj.weight"] = l.self_attn.q_w; + model.tensors[prefix + "self_attn.q_proj.bias"] = l.self_attn.q_b; + model.tensors[prefix + "self_attn.k_proj.weight"] = l.self_attn.k_w; + model.tensors[prefix + "self_attn.k_proj.bias"] = l.self_attn.k_b; + model.tensors[prefix + "self_attn.v_proj.weight"] = l.self_attn.v_w; + model.tensors[prefix + "self_attn.v_proj.bias"] = l.self_attn.v_b; + model.tensors[prefix + "self_attn.out_proj.weight"] = l.self_attn.out_w; + model.tensors[prefix + "self_attn.out_proj.bias"] = l.self_attn.out_b; + + model.tensors[prefix + "norm1.weight"] = l.norm1_w; + model.tensors[prefix + "norm1.bias"] = l.norm1_b; + + model.tensors[prefix + "cross_attn_token_to_image.q_proj.weight"] = l.cross_attn_token_to_img.q_w; + model.tensors[prefix + "cross_attn_token_to_image.q_proj.bias"] = l.cross_attn_token_to_img.q_b; + model.tensors[prefix + "cross_attn_token_to_image.k_proj.weight"] = l.cross_attn_token_to_img.k_w; + model.tensors[prefix + "cross_attn_token_to_image.k_proj.bias"] = l.cross_attn_token_to_img.k_b; + model.tensors[prefix + "cross_attn_token_to_image.v_proj.weight"] = l.cross_attn_token_to_img.v_w; + model.tensors[prefix + "cross_attn_token_to_image.v_proj.bias"] = l.cross_attn_token_to_img.v_b; + model.tensors[prefix + "cross_attn_token_to_image.out_proj.weight"] = l.cross_attn_token_to_img.out_w; + model.tensors[prefix + "cross_attn_token_to_image.out_proj.bias"] = l.cross_attn_token_to_img.out_b; + + model.tensors[prefix + "norm2.weight"] = l.norm2_w; + model.tensors[prefix + "norm2.bias"] = l.norm2_b; + + model.tensors[prefix + "mlp.lin1.weight"] = l.mlp_lin1_w; + model.tensors[prefix + "mlp.lin1.bias"] = l.mlp_lin1_b; + model.tensors[prefix + "mlp.lin2.weight"] = l.mlp_lin2_w; + model.tensors[prefix + "mlp.lin2.bias"] = l.mlp_lin2_b; + + model.tensors[prefix + "norm3.weight"] = l.norm3_w; + model.tensors[prefix + "norm3.bias"] = l.norm3_b; + model.tensors[prefix + "norm4.weight"] = l.norm4_w; + model.tensors[prefix + "norm4.bias"] = l.norm4_b; + + model.tensors[prefix + "cross_attn_image_to_token.q_proj.weight"] = l.cross_attn_img_to_token.q_w; + model.tensors[prefix + "cross_attn_image_to_token.q_proj.bias"] = l.cross_attn_img_to_token.q_b; + model.tensors[prefix + "cross_attn_image_to_token.k_proj.weight"] = l.cross_attn_img_to_token.k_w; + model.tensors[prefix + "cross_attn_image_to_token.k_proj.bias"] = l.cross_attn_img_to_token.k_b; + model.tensors[prefix + "cross_attn_image_to_token.v_proj.weight"] = l.cross_attn_img_to_token.v_w; + model.tensors[prefix + "cross_attn_image_to_token.v_proj.bias"] = l.cross_attn_img_to_token.v_b; + model.tensors[prefix + "cross_attn_image_to_token.out_proj.weight"] = l.cross_attn_img_to_token.out_w; + model.tensors[prefix + "cross_attn_image_to_token.out_proj.bias"] = l.cross_attn_img_to_token.out_b; + } + + dec.transformer_final_attn_token_to_img.q_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + dec.transformer_final_attn_token_to_img.q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + dec.transformer_final_attn_token_to_img.k_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + dec.transformer_final_attn_token_to_img.k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + dec.transformer_final_attn_token_to_img.v_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans/2); + dec.transformer_final_attn_token_to_img.v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans/2); + dec.transformer_final_attn_token_to_img.out_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans/2, n_enc_out_chans); + dec.transformer_final_attn_token_to_img.out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + model.tensors["mask_decoder.transformer.final_attn_token_to_image.q_proj.weight"] = dec.transformer_final_attn_token_to_img.q_w; + model.tensors["mask_decoder.transformer.final_attn_token_to_image.q_proj.bias"] = dec.transformer_final_attn_token_to_img.q_b; + model.tensors["mask_decoder.transformer.final_attn_token_to_image.k_proj.weight"] = dec.transformer_final_attn_token_to_img.k_w; + model.tensors["mask_decoder.transformer.final_attn_token_to_image.k_proj.bias"] = dec.transformer_final_attn_token_to_img.k_b; + model.tensors["mask_decoder.transformer.final_attn_token_to_image.v_proj.weight"] = dec.transformer_final_attn_token_to_img.v_w; + model.tensors["mask_decoder.transformer.final_attn_token_to_image.v_proj.bias"] = dec.transformer_final_attn_token_to_img.v_b; + model.tensors["mask_decoder.transformer.final_attn_token_to_image.out_proj.weight"] = dec.transformer_final_attn_token_to_img.out_w; + model.tensors["mask_decoder.transformer.final_attn_token_to_image.out_proj.bias"] = dec.transformer_final_attn_token_to_img.out_b; + + dec.transformer_norm_final_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + dec.transformer_norm_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + + model.tensors["mask_decoder.transformer.norm_final_attn.weight"] = dec.transformer_norm_final_w; + model.tensors["mask_decoder.transformer.norm_final_attn.bias"] = dec.transformer_norm_final_b; + + dec.output_upscaling_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, n_img_embd, n_enc_out_chans); + dec.output_upscaling_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd); + dec.output_upscaling_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd); + dec.output_upscaling_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd); + dec.output_upscaling_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 2, 2, n_img_embd/2, n_img_embd); + dec.output_upscaling_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2); + + model.tensors["mask_decoder.output_upscaling.0.weight"] = dec.output_upscaling_0_w; + model.tensors["mask_decoder.output_upscaling.0.bias"] = dec.output_upscaling_0_b; + model.tensors["mask_decoder.output_upscaling.1.weight"] = dec.output_upscaling_1_w; + model.tensors["mask_decoder.output_upscaling.1.bias"] = dec.output_upscaling_1_b; + model.tensors["mask_decoder.output_upscaling.3.weight"] = dec.output_upscaling_3_w; + model.tensors["mask_decoder.output_upscaling.3.bias"] = dec.output_upscaling_3_b; + + const int n_hypernet_mpls_count = 4; + dec.output_hypernet_mlps.resize(n_hypernet_mpls_count); + for (int i = 0; i < n_hypernet_mpls_count; ++i) { + auto& mlp = dec.output_hypernet_mlps[i]; + + mlp.w_0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + mlp.b_0 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + mlp.w_1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + mlp.b_1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + mlp.w_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_img_embd/2); + mlp.b_2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_img_embd/2); + + const auto prefix = "mask_decoder.output_hypernetworks_mlps." + std::to_string(i) + "."; + model.tensors[prefix + "layers.0.weight"] = mlp.w_0; + model.tensors[prefix + "layers.0.bias"] = mlp.b_0; + model.tensors[prefix + "layers.1.weight"] = mlp.w_1; + model.tensors[prefix + "layers.1.bias"] = mlp.b_1; + model.tensors[prefix + "layers.2.weight"] = mlp.w_2; + model.tensors[prefix + "layers.2.bias"] = mlp.b_2; + } + + dec.iou_prediction_head_0_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + dec.iou_prediction_head_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + dec.iou_prediction_head_1_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_enc_out_chans); + dec.iou_prediction_head_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_enc_out_chans); + dec.iou_prediction_head_2_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_enc_out_chans, n_pt_embd); + dec.iou_prediction_head_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pt_embd); + + dec.iou_token_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, 1); + dec.mask_tokens_w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_enc_out_chans, n_pt_embd); + + model.tensors["mask_decoder.iou_prediction_head.layers.0.weight"] = dec.iou_prediction_head_0_w; + model.tensors["mask_decoder.iou_prediction_head.layers.0.bias"] = dec.iou_prediction_head_0_b; + model.tensors["mask_decoder.iou_prediction_head.layers.1.weight"] = dec.iou_prediction_head_1_w; + model.tensors["mask_decoder.iou_prediction_head.layers.1.bias"] = dec.iou_prediction_head_1_b; + model.tensors["mask_decoder.iou_prediction_head.layers.2.weight"] = dec.iou_prediction_head_2_w; + model.tensors["mask_decoder.iou_prediction_head.layers.2.bias"] = dec.iou_prediction_head_2_b; + + model.tensors["mask_decoder.iou_token.weight"] = dec.iou_token_w; + model.tensors["mask_decoder.mask_tokens.weight"] = dec.mask_tokens_w; + } + } + + // load weights + { + int n_tensors = 0; + size_t total_size = 0; + + fprintf(stderr, "%s: ", __func__); + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + + if (fin.eof()) { + break; + } + + int64_t nelements = 1; + int64_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + int32_t ne_cur; + fin.read(reinterpret_cast(&ne_cur), sizeof(ne_cur)); + ne[i] = ne_cur; + nelements *= ne[i]; + } + + std::string name(length, 0); + fin.read(&name[0], length); + + if (model.tensors.find(name.data()) == model.tensors.end()) { + fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + //printf("ne0 = %jd, ne1 = %jd, ne2 = %jd, ne3 = %jd\n", ne[0], ne[1], ne[2], ne[3]); + + if (ggml_nelements(tensor) != nelements) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %d, expected %d\n", + __func__, name.data(), (int) nelements, (int) ggml_nelements(tensor)); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) { + fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n", + __func__, name.data(), + (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3], + (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3]); + return false; + } + + size_t bpe = 0; + + switch (ftype) { + case 0: bpe = ggml_type_size(GGML_TYPE_F32); break; + case 1: bpe = ggml_type_size(GGML_TYPE_F16); break; + case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break; + case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break; + default: + { + fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype); + return false; + } + }; + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), (size_t) nelements*bpe); + return false; + } + + fin.read(reinterpret_cast(tensor->data), ggml_nbytes(tensor)); + + total_size += ggml_nbytes(tensor); + if (++n_tensors % 8 == 0) { + fprintf(stderr, "."); + fflush(stdout); + } + } + + if (n_tensors != int(model.tensors.size())) { + fprintf(stderr, "%s: model file has %d tensors, but %d tensors were expected\n", __func__, n_tensors, (int) model.tensors.size()); + return false; + } + + fprintf(stderr, " done\n"); + + fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); + } + + fin.close(); + + return true; +} + +struct ggml_tensor * sam_fill_dense_pe( + const sam_model & model, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + sam_state & state) { + const auto & hparams = model.hparams; + const auto & enc = model.enc_prompt; + + + const int32_t n_img_embd = hparams.n_img_embd(); + struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd); + ggml_set_name(xy_embed_stacked, "xy_embed_stacked"); + ggml_set_input(xy_embed_stacked); + + struct ggml_tensor * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), xy_embed_stacked); + + cur = ggml_scale(ctx0, cur, float(2.0*M_PI)); + + // concat + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192 + { + struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL); + struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL); + + cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1]))); + } + + struct ggml_tensor * pe_img_dense = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3)); + ggml_build_forward_expand(gf, pe_img_dense); + + return pe_img_dense; +} + +struct ggml_tensor* sam_layer_norm_2d( + struct ggml_context * ctx0, + struct ggml_tensor * layer, + int n_channels, + struct ggml_tensor * w, + struct ggml_tensor * b, + float eps) { + // LayerNorm2d + // normalize along channel dimmension + // TODO: better implementation + layer = ggml_permute(ctx0, + ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps), + 2, 0, 1, 3); + + layer = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer), + layer), + ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer)); + + return layer; +} + +struct ggml_cgraph * sam_encode_image( + const sam_model & model, + sam_state & state, + const sam_image_f32 & img) { + + const auto & hparams = model.hparams; + const auto & enc = model.enc_img; + + const int32_t n_enc_state = hparams.n_enc_state; + const int32_t n_enc_layer = hparams.n_enc_layer; + const int32_t n_enc_head = hparams.n_enc_head; + const int32_t n_enc_head_dim = hparams.n_enc_head_dim(); + const int32_t n_enc_out_chans = hparams.n_enc_out_chans; + const int32_t n_img_size = hparams.n_img_size(); + const int32_t n_window_size = hparams.n_window_size(); + + struct ggml_init_params ggml_params = { + /*.mem_size =*/ state.buf_compute_img_enc.size(), + /*.mem_buffer =*/ state.buf_compute_img_enc.data(), + /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements + }; + + struct ggml_context * ctx0 = ggml_init(ggml_params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3, 1); + ggml_set_name(inp, "inp"); + ggml_set_input(inp); + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L392 + struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx0, enc.proj_w, inp); + cur = ggml_add_inplace(ctx0, + cur, + ggml_repeat(ctx0, enc.proj_b, cur)); + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L394 + // keep in F32 + cur = ggml_cont(ctx0, + ggml_permute(ctx0, cur, 1, 2, 0, 3)); + + // convert to F16 + //cur = ggml_cpy(ctx0, + // ggml_permute(ctx0, cur, 1, 2, 0, 3), + // ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_enc_state, n_img_embd, n_img_embd)); + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L108-L109 + cur = ggml_add_inplace(ctx0, cur, enc.pe); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_enc_layer; ++il) { + const auto & layer = enc.layers[il]; + + // norm + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168 + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_mul(ctx0, cur, layer.norm1_w); + cur = ggml_add_inplace(ctx0, cur, layer.norm1_b); + } + + const int64_t w0 = cur->ne[1]; + const int64_t h0 = cur->ne[2]; + + if (hparams.is_global_attn(il) == false) { + // local attention layer - apply window partition + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172 + cur = ggml_win_part(ctx0, cur, n_window_size); + } + + const int64_t W = cur->ne[1]; + const int64_t H = cur->ne[2]; + + // self-attention + { + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + cur = ggml_add_inplace(ctx0, cur, layer.qkv_b); + + // split qkv into separate tensors + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L225-L229 + const int B = cur->ne[3]; + + cur = ggml_reshape_4d(ctx0, cur, n_enc_state, 3, W*H, B); + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2)); + + struct ggml_tensor * Q; + struct ggml_tensor * K; + struct ggml_tensor * V; + + Q = ggml_view_3d (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 0*cur->nb[3]); + Q = ggml_reshape_4d(ctx0, Q, n_enc_head_dim, n_enc_head, W*H, B); + Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, n_enc_head_dim, W*H, B*n_enc_head); + + K = ggml_view_3d (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 1*cur->nb[3]); + K = ggml_reshape_4d(ctx0, K, n_enc_head_dim, n_enc_head, W*H, B); + K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, n_enc_head_dim, W*H, B*n_enc_head); + + V = ggml_view_3d (ctx0, cur, n_enc_state, W*H, B, cur->nb[1], cur->nb[2], 2*cur->nb[3]); + V = ggml_reshape_4d(ctx0, V, n_enc_head_dim, n_enc_head, W*H, B); + V = ggml_cont (ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // transposed + V = ggml_reshape_3d(ctx0, V, W*H, n_enc_head_dim, B*n_enc_head); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + 1.0f/sqrtf(n_enc_head_dim)); + + struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W); + struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H); + + struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Q, n_enc_head_dim, W, H, B*n_enc_head); + + struct ggml_tensor * rel_w = ggml_cont(ctx0, ggml_permute(ctx0, + ggml_mul_mat(ctx0, + rw, + ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), + 0, 2, 1, 3)); + struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); + + struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + cur = + ggml_reshape_4d(ctx0, + ggml_cont(ctx0, + ggml_permute(ctx0, + ggml_reshape_4d(ctx0, KQV, n_enc_head_dim, W*H, n_enc_head, B), + 0, 2, 1, 3)), + n_enc_state, W, H, B); + + cur = ggml_mul_mat(ctx0, layer.proj_w, cur); + cur = ggml_add_inplace(ctx0, cur, layer.proj_b); + } + + if (hparams.is_global_attn(il) == false) { + // local attention layer - reverse window partition + cur = ggml_win_unpart(ctx0, cur, w0, h0, n_window_size); + } + + cur = ggml_add_inplace(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_mul(ctx0, cur, layer.norm2_w); + cur = ggml_add_inplace(ctx0, cur, layer.norm2_b); + } + + // fully connected + cur = ggml_mul_mat(ctx0, layer.mlp_lin1_w, cur); + cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin1_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, layer.mlp_lin2_w, cur); + cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin2_b); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3)); + + cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur); + + cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b, hparams.eps); + + cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur); + + cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps); + + cur = ggml_cpy(ctx0, cur, state.embd_img); + + ggml_build_forward_expand(gf, cur); + ggml_disconnect_node_from_graph(state.embd_img); + + //ggml_graph_print(&gf); + + ggml_free(ctx0); + + ggml_gallocr_alloc_graph(state.allocr, gf); + + { + struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "inp"); + float * data = (float *) ggml_get_data(inp); + + const int nx = img.nx; + const int ny = img.ny; + const int n = nx*ny; + + GGML_ASSERT(nx == n_img_size && ny == n_img_size); + + for (int k = 0; k < 3; k++) { + for (int y = 0; y < ny; y++) { + for (int x = 0; x < nx; x++) { + data[k*n + y*nx + x] = img.data[3*(y*nx + x) + k]; + } + } + } + } + + return gf; +} + + +struct prompt_encoder_result { + struct ggml_tensor * embd_prompt_sparse = {}; + struct ggml_tensor * embd_prompt_dense = {}; +}; + +struct ggml_tensor * sam_prompt_encode_pe_encoding( + const sam_encoder_prompt & enc, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * coords) { + + auto * cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, enc.pe)), coords); + cur = ggml_scale(ctx0, cur, float(2.0*M_PI)); + + // concat + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192 + { + struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL); + struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL); + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1]))); + + } + return cur; + +} +// encode a prompt +// +// - points +// - boxes +// - masks +// +// TODO: currently just encode a single point for simplicity +// +prompt_encoder_result sam_encode_prompt( + const sam_model & model, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + sam_state & state, + const sam_prompt & prompt) { + + const auto & hparams = model.hparams; + const auto & enc = model.enc_prompt; + + struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, 2); + ggml_set_name(inp, "prompt_input"); + ggml_set_input(inp); + + auto * embd_prompt_sparse = [&]() { + switch (prompt.prompt_type) { + case SAM_PROMPT_TYPE_POINT: { + // PromptEncoder._embed_points + auto * pt_embd = sam_prompt_encode_pe_encoding(enc, ctx0, gf, inp); + + // overwrite label == -1 with not_a_point_embed.weight + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L86 + // TODO: extend for multiple points + auto * pt_embd_not = ggml_view_2d(ctx0, pt_embd, pt_embd->ne[0], 1, pt_embd->nb[1], pt_embd->nb[1]); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, pt_embd_not)); + + // add point_embeddings[1] to label == 1 + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90 + auto * pt_embd1 = ggml_view_2d(ctx0, pt_embd, pt_embd->ne[0], 1, pt_embd->nb[1], 0); + ggml_build_forward_expand(gf, ggml_add_inplace(ctx0, pt_embd1, enc.pt_embd[1])); + + return pt_embd; + } break; + case SAM_PROMPT_TYPE_BOX: { + // PromptEncoder._embed_boxes + auto * corner_embd = sam_prompt_encode_pe_encoding(enc, ctx0, gf, inp); + + // corner_embd[:, 0, :] += self.point_embeddings[2].weight + // corner_embd[:, 1, :] += self.point_embeddings[3].weight + auto * corner0 = ggml_view_2d( + ctx0, corner_embd, corner_embd->ne[0], 1, corner_embd->nb[1], 0); + auto * corner1 = ggml_view_2d( + ctx0, corner_embd, corner_embd->ne[0], 1, corner_embd->nb[1], corner_embd->nb[1]); + + ggml_build_forward_expand(gf, ggml_add_inplace(ctx0, corner0, enc.pt_embd[2])); + ggml_build_forward_expand(gf, ggml_add_inplace(ctx0, corner1, enc.pt_embd[3])); + + return corner_embd; + } break; + } + }(); + + ggml_build_forward_expand(gf, embd_prompt_sparse); + + struct ggml_tensor * embd_prompt_dense = ggml_repeat(ctx0, + ggml_cont(ctx0, + ggml_view_3d(ctx0, enc.no_mask_embd_w, + 1, 1, enc.no_mask_embd_w->ne[0], enc.no_mask_embd_w->nb[0], enc.no_mask_embd_w->nb[0], 0)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hparams.n_img_embd(), hparams.n_img_embd(), hparams.n_enc_out_chans)); + + ggml_build_forward_expand(gf, embd_prompt_dense); + + //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + + prompt_encoder_result res; + res.embd_prompt_sparse = embd_prompt_sparse; + res.embd_prompt_dense = embd_prompt_dense; + return res; +} + +struct ggml_tensor* sam_decode_mask_transformer_attn( + const sam_layer_dec_transformer_attn & attn, + struct ggml_tensor * queries, + struct ggml_tensor * keys, + struct ggml_tensor * values, + struct ggml_context * ctx0, + const sam_model & model) { + const auto & hparams = model.hparams; + const int n_head = hparams.n_dec_heads; + + struct ggml_tensor * Qcur = {}; + struct ggml_tensor * Kcur = {}; + struct ggml_tensor * Vcur = {}; + + Qcur = ggml_mul_mat(ctx0, attn.q_w, queries); + Qcur = ggml_add_inplace(ctx0, Qcur, attn.q_b); + + Kcur = ggml_mul_mat(ctx0, attn.k_w, keys); + Kcur = ggml_add_inplace(ctx0, Kcur, attn.k_b); + + Vcur = ggml_mul_mat(ctx0, attn.v_w, values); + Vcur = ggml_add_inplace(ctx0, Vcur, attn.v_b); + + struct ggml_tensor * Q = {}; + struct ggml_tensor * K = {}; + struct ggml_tensor * V = {}; + + Q = ggml_reshape_4d(ctx0, Qcur, Qcur->ne[0]/n_head, n_head, Qcur->ne[1], Qcur->ne[2]); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + + K = ggml_reshape_4d(ctx0, Kcur, Kcur->ne[0]/n_head, n_head, Kcur->ne[1], Kcur->ne[2]); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + + V = ggml_reshape_4d(ctx0, Vcur, Vcur->ne[0]/n_head, n_head, Vcur->ne[1], Vcur->ne[2]); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); + + // Q * K + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f/sqrt(float(Q->ne[0]))); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, KQ_soft_max, ggml_cont(ctx0, ggml_transpose(ctx0, V))); + + struct ggml_tensor * KQV_merged = ggml_cont(ctx0, ggml_transpose(ctx0, KQV)); + KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3)); + KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]); + KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged); + KQV_merged = ggml_add_inplace(ctx0, KQV_merged, attn.out_b); + + return KQV_merged; +} + +struct ggml_tensor * sam_decode_mask_mlp_relu_3( + struct ggml_tensor * in, + struct ggml_tensor * w_0, + struct ggml_tensor * b_0, + struct ggml_tensor * w_1, + struct ggml_tensor * b_1, + struct ggml_tensor * w_2, + struct ggml_tensor * b_2, + struct ggml_context * ctx0) { + + struct ggml_tensor * cur = {}; + cur = ggml_mul_mat(ctx0, w_0, in); + cur = ggml_add_inplace(ctx0, cur, b_0); + + cur = ggml_relu_inplace(ctx0, cur); + + cur = ggml_mul_mat(ctx0, w_1, cur); + cur = ggml_add_inplace(ctx0, cur, b_1); + + cur = ggml_relu_inplace(ctx0, cur); + + cur = ggml_mul_mat(ctx0, w_2, cur); + cur = ggml_add_inplace(ctx0, cur, b_2); + + return cur; +} + +bool sam_decode_mask( + const sam_model & model, + const prompt_encoder_result & prompt, + struct ggml_tensor * pe_img, + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + sam_state & state, + const bool multimask_output) { + + const auto & hparams = model.hparams; + const auto & dec = model.dec; + const int n_img_embd = hparams.n_img_embd(); + + struct ggml_tensor * tokens = {}; + { + // Concatenate output tokens + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L120 + const auto& sparse = prompt.embd_prompt_sparse; + + tokens = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, dec.iou_token_w->ne[0], dec.iou_token_w->ne[1] + dec.mask_tokens_w->ne[1] + sparse->ne[1], sparse->ne[2]); + + const size_t offsets[3] = { 0, dec.iou_token_w->ne[1]*tokens->nb[1], dec.iou_token_w->ne[1]*tokens->nb[1] + dec.mask_tokens_w->ne[1]*tokens->nb[1] }; + ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.iou_token_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.iou_token_w->ne[1], tokens->nb[1], offsets[0]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1]))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, sparse, ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1], tokens->nb[1], offsets[2]))); + // TODO: Sparse prompt embeddings can have more than one point + } + + + struct ggml_tensor * src = {}; + struct ggml_tensor * pos_src = {}; + int srcNE[4] = { 0, 0, 0, 0 }; + { + // Expand per-image data in the batch direction to be per-mask + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L125 + src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.embd_img->ne[0], state.embd_img->ne[1], state.embd_img->ne[2], tokens->ne[2]); + + src = ggml_add(ctx0, + ggml_repeat(ctx0, + state.embd_img, + src), + prompt.embd_prompt_dense); + + srcNE[0] = src->ne[0]; + srcNE[1] = src->ne[1]; + srcNE[2] = src->ne[2]; + srcNE[3] = src->ne[3]; + + // flatten & permute + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83 + src = ggml_cont(ctx0, ggml_permute(ctx0, + ggml_view_3d(ctx0, + src, + src->ne[0]*src->ne[1], + src->ne[2], + src->ne[3], + src->nb[2], + src->nb[3], + 0), + 1, 0, 2, 3)); + + pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, pe_img->ne[0], pe_img->ne[1], pe_img->ne[2], tokens->ne[2]); + pos_src = ggml_repeat(ctx0, + pe_img, + pos_src); + + // flatten & permute + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L83 + pos_src = ggml_cont(ctx0, ggml_permute(ctx0, + ggml_view_3d(ctx0, + pos_src, + pos_src->ne[0]*pos_src->ne[1], + pos_src->ne[2], + pos_src->ne[3], + pos_src->nb[2], + pos_src->nb[3], + 0), + 1, 0, 2, 3)); + } + + struct ggml_tensor * queries = tokens; + struct ggml_tensor * keys = src; + { + // Run the transformer + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L62 + for (int i = 0; i < int(model.dec.transformer_layers.size()); ++i) { + const auto& tfm_layer = model.dec.transformer_layers[i]; + + // Self attention block + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L154 + const bool skip_first_layer_pe = i == 0; + if (skip_first_layer_pe) { + queries = sam_decode_mask_transformer_attn(tfm_layer.self_attn, queries, queries, queries, ctx0, model); + } + else { + struct ggml_tensor * q_0 = ggml_add(ctx0, queries, tokens); + + struct ggml_tensor * self_attn = sam_decode_mask_transformer_attn(tfm_layer.self_attn, q_0, q_0, queries, ctx0, model); + queries = ggml_add(ctx0, queries, self_attn); + } + + queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, tfm_layer.norm1_w), + tfm_layer.norm1_b); + + // Cross attention block, tokens attending to image embedding + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L163 + struct ggml_tensor * q_1 = ggml_add(ctx0, queries, tokens); + struct ggml_tensor * k_1 = ggml_add(ctx0, keys, pos_src); + + struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model); + + queries = ggml_add_inplace(ctx0, queries, cross_attn_token_to_img); + queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, tfm_layer.norm2_w), + tfm_layer.norm2_b); + + // MLP block + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L170 + struct ggml_tensor * mlp_out = ggml_mul_mat(ctx0, + tfm_layer.mlp_lin1_w, + queries); + + mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin1_b); + + // RELU activation + mlp_out = ggml_relu_inplace(ctx0, mlp_out); + mlp_out = ggml_mul_mat(ctx0, tfm_layer.mlp_lin2_w, mlp_out); + + mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin2_b); + + queries = ggml_add_inplace(ctx0, queries, mlp_out); + queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, tfm_layer.norm3_w), + tfm_layer.norm3_b); + + // Cross attention block, image embedding attending to tokens + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L175 + struct ggml_tensor * q_2 = ggml_add(ctx0, queries, tokens); + struct ggml_tensor * k_2 = ggml_add(ctx0, keys, pos_src); + + struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model); + keys = ggml_add_inplace(ctx0, keys, cross_attn_img_to_token); + keys = ggml_norm_inplace(ctx0, keys, hparams.eps_decoder_transformer); + keys = ggml_add_inplace(ctx0, + ggml_mul(ctx0, keys, tfm_layer.norm4_w), + tfm_layer.norm4_b); + } + + // Apply the final attention layer from the points to the image + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L99 + struct ggml_tensor * q = ggml_add(ctx0, queries, tokens); + struct ggml_tensor * k = ggml_add(ctx0, keys, pos_src); + + struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model); + + queries = ggml_add_inplace(ctx0, queries, final_attn_token_to_img); + queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, dec.transformer_norm_final_w), + dec.transformer_norm_final_b); + } + + + struct ggml_tensor * iou_pred = ggml_view_2d(ctx0, queries, queries->ne[0], queries->ne[2], queries->nb[2], 0); + const int num_mask_tokens = 4; // num_multimask_outputs + 1 + struct ggml_tensor * mask_tokens_out = ggml_view_3d(ctx0, queries, queries->ne[0], num_mask_tokens, queries->ne[2], queries->nb[1], num_mask_tokens*queries->nb[1], queries->nb[1]); + + // Upscale mask embeddings and predict masks using the mask tokens + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L136 + keys = ggml_cont(ctx0, ggml_transpose(ctx0, keys)); + keys = ggml_view_4d(ctx0, keys, srcNE[0], srcNE[1], srcNE[2], srcNE[3], srcNE[0]*keys->nb[0], keys->nb[1], keys->nb[2], 0); + // ggml_build_forward_expand(gf, keys); + struct ggml_tensor * upscaled_embedding = {}; + { + // ConvTranspose2d + keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2); + keys = ggml_add_inplace(ctx0, keys, ggml_repeat(ctx0, + ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]), + keys)); + + keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps); + + // GELU activation + keys = ggml_gelu_inplace(ctx0, keys); + + // ConvTranspose2d + keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2); + keys = ggml_add_inplace(ctx0, ggml_repeat(ctx0, + ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]), + keys), keys); + // GELU activation + keys = ggml_gelu_inplace(ctx0, keys); + upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]); + upscaled_embedding = ggml_cont(ctx0, ggml_transpose(ctx0, upscaled_embedding)); // TODO: Shouldn't be needed + } + + struct ggml_tensor * hyper_in = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_img_embd/2, num_mask_tokens, mask_tokens_out->ne[2]); + + for (int i = 0; i < num_mask_tokens; ++i) { + const auto& mlp = dec.output_hypernet_mlps[i]; + struct ggml_tensor * in = ggml_view_2d(ctx0, mask_tokens_out, mask_tokens_out->ne[0], mask_tokens_out->ne[2], mask_tokens_out->nb[1], i*mask_tokens_out->nb[1]); + struct ggml_tensor * out = sam_decode_mask_mlp_relu_3(in, mlp.w_0, mlp.b_0, mlp.w_1, mlp.b_1, mlp.w_2, mlp.b_2, ctx0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1]))); + } + + struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding); + masks = ggml_cont(ctx0, ggml_transpose(ctx0, masks)); // TODO: Shouldn't be needed + masks = ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]); + + // Generate mask quality predictions + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146 + iou_pred = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0); + + // Select the correct mask or masks for output + // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L101 + if (multimask_output) { + iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, iou_pred->ne[0] - 1, iou_pred->nb[0]), state.iou_predictions); + masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3], + masks->nb[1], masks->nb[2], masks->nb[3], masks->nb[2] /* offset*/); + masks = ggml_cpy(state.ctx, masks, state.low_res_masks); + } else { + iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, 1, 0), ggml_view_1d(ctx0, state.iou_predictions, 1, 0)); + masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], 1, masks->ne[3], + masks->nb[1], masks->nb[2], masks->nb[3], 0); + auto * low_res_mask = ggml_view_4d(ctx0, state.low_res_masks, masks->ne[0], masks->ne[1], 1, masks->ne[3], + masks->nb[1], masks->nb[2], masks->nb[3], 0); + masks = ggml_cpy(state.ctx, masks, low_res_mask); + } + + ggml_build_forward_expand(gf, masks); + ggml_build_forward_expand(gf, iou_pred); + + ggml_disconnect_node_from_graph(state.low_res_masks); + ggml_disconnect_node_from_graph(state.iou_predictions); + + return true; +} + +bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state, const std::string & fname, const bool multimask_output) { + if (state.low_res_masks->ne[2] == 0) return true; + if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) { + printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]); + return false; + } + + const int n_img_size = hparams.n_img_size(); + const float mask_threshold = hparams.mask_threshold; + const float iou_threshold = hparams.iou_threshold; + const float stability_score_threshold = hparams.stability_score_threshold; + const float intersection_threshold = mask_threshold + hparams.stability_score_offset; + const float union_threshold = mask_threshold - hparams.stability_score_offset; + + const int ne0 = state.low_res_masks->ne[0]; + const int ne1 = state.low_res_masks->ne[1]; + const int ne2 = multimask_output ? state.low_res_masks->ne[2] : 1; + + // Remove padding and upscale masks to the original image size. + // ref: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/modeling/sam.py#L140 + + const float preprocess_scale = std::max(nx, ny) / float(n_img_size); + const int cropped_nx = int(nx / preprocess_scale + 0.5f); + const int cropped_ny = int(ny / preprocess_scale + 0.5f); + + const float scale_x_1 = (float)ne0 / (float)n_img_size; + const float scale_y_1 = (float)ne1 / (float)n_img_size; + + const float scale_x_2 = float(cropped_nx) / float(nx); + const float scale_y_2 = float(cropped_ny) / float(ny); + + const auto iou_data = (float*)state.iou_predictions->data; + + for (int i = 0; i < ne2; ++i) { + if (iou_threshold > 0.f && iou_data[i] < iou_threshold) { + printf("Skipping mask %d with iou %f below threshold %f\n", i, iou_data[i], iou_threshold); + continue; // Filtering masks with iou below the threshold + } + + std::vector mask_data(n_img_size*n_img_size); + { + const float* data = (float *) state.low_res_masks->data + i*ne0*ne1; + + for (int iy = 0; iy < n_img_size; ++iy) { + for (int ix = 0; ix < n_img_size; ++ix) { + const float sx = std::max(scale_x_1*(ix + 0.5f) - 0.5f, 0.0f); + const float sy = std::max(scale_y_1*(iy + 0.5f) - 0.5f, 0.0f); + + const int x0 = std::max(0, (int)sx); + const int y0 = std::max(0, (int)sy); + + const int x1 = std::min(x0 + 1, ne0 - 1); + const int y1 = std::min(y0 + 1, ne1 - 1); + + const float dx = sx - x0; + const float dy = sy - y0; + + const int j00 = y0*ne0 + x0; + const int j01 = y0*ne0 + x1; + const int j10 = y1*ne0 + x0; + const int j11 = y1*ne0 + x1; + + const float v00 = data[j00]; + const float v01 = data[j01]; + const float v10 = data[j10]; + const float v11 = data[j11]; + + const float v0 = (1-dx)*v00 + dx*v01; + const float v1 = (1-dx)*v10 + dx*v11; + + const float v = (1-dy)*v0 + dy*v1; + + mask_data[iy*n_img_size + ix] = v; + } + } + } + + int intersections = 0; + int unions = 0; + sam_image_u8 res; + int min_iy = ny; + int max_iy = 0; + int min_ix = nx; + int max_ix = 0; + { + const float* data = mask_data.data(); + + res.nx = nx; + res.ny = ny; + res.data.resize(nx*ny); + + for (int iy = 0; iy < ny; ++iy) { + for (int ix = 0; ix < nx; ++ix) { + const float sx = std::max(scale_x_2*(ix + 0.5f) - 0.5f, 0.0f); + const float sy = std::max(scale_y_2*(iy + 0.5f) - 0.5f, 0.0f); + + const int x0 = std::max(0, (int)sx); + const int y0 = std::max(0, (int)sy); + + const int x1 = std::min(x0 + 1, cropped_nx - 1); + const int y1 = std::min(y0 + 1, cropped_ny - 1); + + const float dx = sx - x0; + const float dy = sy - y0; + + const int j00 = y0*n_img_size + x0; + const int j01 = y0*n_img_size + x1; + const int j10 = y1*n_img_size + x0; + const int j11 = y1*n_img_size + x1; + + const float v00 = data[j00]; + const float v01 = data[j01]; + const float v10 = data[j10]; + const float v11 = data[j11]; + + const float v0 = (1-dx)*v00 + dx*v01; + const float v1 = (1-dx)*v10 + dx*v11; + + const float v = (1-dy)*v0 + dy*v1; + + if (v > intersection_threshold) { + intersections++; + } + if (v > union_threshold) { + unions++; + } + if (v > mask_threshold) { + min_iy = std::min(min_iy, iy); + max_iy = std::max(max_iy, iy); + min_ix = std::min(min_ix, ix); + max_ix = std::max(max_ix, ix); + + res.data[iy*nx + ix] = 255; + } + } + } + } + + const float stability_score = float(intersections) / float(unions); + if (stability_score_threshold > 0.f && stability_score < stability_score_threshold) { + printf("Skipping mask %d with stability score %f below threshold %f\n", i, stability_score, stability_score_threshold); + continue; // Filtering masks with stability score below the threshold + } + + printf("Mask %d: iou = %f, stability_score = %f, bbox (%d, %d), (%d, %d)\n", + i, iou_data[i], stability_score, min_ix, max_ix, min_iy, max_iy); + + const std::string filename = multimask_output ? fname + std::to_string(i) + ".png" : fname + ".png"; + if (!stbi_write_png(filename.c_str(), res.nx, res.ny, 1, res.data.data(), res.nx)) { + printf("%s: failed to write mask %s\n", __func__, filename.c_str()); + return false; + } + } + + + return true; +} + + +struct ggml_cgraph * sam_build_fast_graph( + const sam_model & model, + sam_state & state, + const int nx, + const int ny, + const sam_prompt & prompt, + const bool multimask_output) { + + struct ggml_init_params ggml_params = { + /*.mem_size =*/ state.buf_compute_fast.size(), + /*.mem_buffer =*/ state.buf_compute_fast.data(), + /*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements + }; + + struct ggml_context * ctx0 = ggml_init(ggml_params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, prompt); + if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) { + fprintf(stderr, "%s: failed to encode prompt\n", __func__); + return {}; + } + + struct ggml_tensor * pe_img_dense = sam_fill_dense_pe(model, ctx0, gf, state); + if (!pe_img_dense) { + fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__); + return {}; + } + + if (!sam_decode_mask(model, enc_res, pe_img_dense, ctx0, gf, state, multimask_output)) { + fprintf(stderr, "%s: failed to decode mask\n", __func__); + return {}; + } + + ggml_free(ctx0); + + ggml_gallocr_alloc_graph(state.allocr, gf); + + struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "prompt_input"); + auto * data = (float *) inp->data; + + // Transform prompt (point or box) + { + // https://github.com/facebookresearch/segment-anything/blob/dca509fe793f601edb92606367a655c15ac00fdf/segment_anything/utils/transforms.py#L33 + // The point scaling here is greatly simplified but mathematically equivalent. + const auto scale = 1.0F / std::max(nx, ny); + + switch (prompt.prompt_type) { + case SAM_PROMPT_TYPE_POINT: { + const auto & pt = prompt.pt; + + // set the input by converting the [0, 1] coordinates to [-1, 1] + data[0] = 2.0f*pt.x*scale - 1.0f; + data[1] = 2.0f*pt.y*scale - 1.0f; + + // padding + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L81-L85 + data[2] = 2.0f*(0.0f) - 1.0f; + data[3] = 2.0f*(0.0f) - 1.0f; + } break; + case SAM_PROMPT_TYPE_BOX: { + const auto & box = prompt.box; + + data[0] = 2.0f*box.x1*scale - 1.0f; + data[1] = 2.0f*box.y1*scale - 1.0f; + data[2] = 2.0f*box.x2*scale - 1.0f; + data[3] = 2.0f*box.y2*scale - 1.0f; + } break; + } + } + + // from sam_fill_dense_pe + { + struct ggml_tensor * xy_embed_stacked = ggml_graph_get_tensor(gf, "xy_embed_stacked"); + const int32_t n_img_embd = model.hparams.n_img_embd(); + const float n_img_embd_inv = 1.0f / n_img_embd; + float * data = (float *) ggml_get_data(xy_embed_stacked); + for (int i = 0; i < n_img_embd; ++i) { + const int row = 2*i*n_img_embd; + const float y_val = 2 * (i + 0.5f) * n_img_embd_inv - 1; + for (int j = 0; j < n_img_embd; ++j) { + const float x_val = 2 * (j + 0.5f) * n_img_embd_inv - 1; + data[row + 2*j + 0] = x_val; + data[row + 2*j + 1] = y_val; + } + } + } + + return gf; +} + +void sam_print_usage(int argc, char ** argv, const sam_params & params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); + fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " -m FNAME, --model FNAME\n"); + fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); + fprintf(stderr, " -i FNAME, --inp FNAME\n"); + fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str()); + fprintf(stderr, " -o FNAME, --out FNAME\n"); + fprintf(stderr, " mask file name prefix (default: %s)\n", params.fname_out.c_str()); + fprintf(stderr, " -sm, --single-mask\n"); + fprintf(stderr, " single mask output (default multi mask output)\n"); + fprintf(stderr, "SAM hyperparameters:\n"); + fprintf(stderr, " -mt FLOAT, --mask-threshold\n"); + fprintf(stderr, " mask threshold (default: %f)\n", params.mask_threshold); + fprintf(stderr, " -it FLOAT, --iou-threshold\n"); + fprintf(stderr, " iou threshold (default: %f)\n", params.iou_threshold); + fprintf(stderr, " -st FLOAT, --score-threshold\n"); + fprintf(stderr, " score threshold (default: %f)\n", params.stability_score_threshold); + fprintf(stderr, " -so FLOAT, --score-offset\n"); + fprintf(stderr, " score offset (default: %f)\n", params.stability_score_offset); + fprintf(stderr, " -e FLOAT, --epsilon\n"); + fprintf(stderr, " epsilon (default: %f)\n", params.eps); + fprintf(stderr, " -ed FLOAT, --epsilon-decoder-transformer\n"); + fprintf(stderr, " epsilon decoder transformer (default: %f)\n", params.eps_decoder_transformer); + fprintf(stderr, "SAM prompt:\n"); + fprintf(stderr, " -p [x,y], --point-prompt\n"); + fprintf(stderr, " point to be used as prompt for SAM (default: %f,%f). Must be in a format FLOAT,FLOAT \n", params.prompt.pt.x, params.prompt.pt.y); + fprintf(stderr, " -b [x1,y1,x2,y2], --box-prompt\n"); + fprintf(stderr, " box to be used as prompt for SAM (default: %f,%f,%f,%f). Must be in a format FLOAT,FLOAT,FLOAT,FLOAT \n", + params.prompt.box.x1, params.prompt.box.y1, params.prompt.box.x2, params.prompt.box.y2); + fprintf(stderr, "\n"); +} + +bool sam_params_parse(int argc, char ** argv, sam_params & params) { + + bool use_point_prompt = false; + bool use_box_prompt = false; + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-s" || arg == "--seed") { + params.seed = std::stoi(argv[++i]); + } else if (arg == "-t" || arg == "--threads") { + params.n_threads = std::stoi(argv[++i]); + } else if (arg == "-m" || arg == "--model") { + params.model = argv[++i]; + } else if (arg == "-i" || arg == "--inp") { + params.fname_inp = argv[++i]; + } else if (arg == "-o" || arg == "--out") { + params.fname_out = argv[++i]; + } else if (arg == "-sm" || arg == "--single-mask") { + params.multimask_output = false; + } else if (arg == "-mt" || arg == "--mask-threshold") { + params.mask_threshold = std::stof(argv[++i]); + } else if (arg == "-it" || arg == "--iou-threshold") { + params.iou_threshold = std::stof(argv[++i]); + } else if (arg == "-st" || arg == "--score-threshold") { + params.stability_score_threshold = std::stof(argv[++i]); + } else if (arg == "-so" || arg == "--score-offset") { + params.stability_score_offset = std::stof(argv[++i]); + } else if (arg == "-e" || arg == "--epsilon") { + params.eps = std::stof(argv[++i]); + } else if (arg == "-ed" || arg == "--epsilon-decoder-transformer") { + params.eps_decoder_transformer = std::stof(argv[++i]); + } else if (arg == "-p" || arg == "--point-prompt") { + // TODO multiple points per model invocation + use_point_prompt = true; + char* point = argv[++i]; + + char* coord = strtok(point, ","); + if (!coord){ + fprintf(stderr, "Error while parsing prompt!\n"); + exit(1); + } + params.prompt.pt.x = std::stof(coord); + + coord = strtok(NULL, ","); + if (!coord){ + fprintf(stderr, "Error while parsing prompt!\n"); + exit(1); + } + params.prompt.pt.y = std::stof(coord); + } else if (arg == "-b" || arg == "--box-prompt") { + use_box_prompt = true; + char * box_prompt = argv[++i]; + float box_vals[4]; + + char * val = strtok(box_prompt, ","); + if (!val) { + fprintf(stderr, "Error while parsing prompt!\n"); + exit(1); + } + box_vals[0] = std::stof(val); + + for (int j = 1; j < 4; ++j) { + char * val = strtok(NULL, ","); + if (!val) { + fprintf(stderr, "Error while parsing prompt!\n"); + exit(1); + } + box_vals[j] = std::stof(val); + } + + params.prompt.box.x1 = box_vals[0]; + params.prompt.box.y1 = box_vals[1]; + params.prompt.box.x2 = box_vals[2]; + params.prompt.box.y2 = box_vals[3]; + } else if (arg == "-h" || arg == "--help") { + sam_print_usage(argc, argv, params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + sam_print_usage(argc, argv, params); + exit(0); + } + } + + if (use_box_prompt && use_point_prompt) { + fprintf(stderr, "Error: use either point or box prompt, not both.\n"); + exit(1); + } + + params.prompt.prompt_type = SAM_PROMPT_TYPE_POINT; + if (use_box_prompt) { + params.prompt.prompt_type = SAM_PROMPT_TYPE_BOX; + } + + return true; +} + + +int main(int argc, char ** argv) { + const int64_t t_main_start_us = ggml_time_us(); + + sam_params params; + params.model = "models/sam-vit-b/ggml-model-f16.bin"; + + sam_model model; + sam_state state; + int64_t t_load_us = 0; + + if (sam_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); + + // load the image + sam_image_u8 img0; + if (!sam_image_load_from_file(params.fname_inp, img0)) { + fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str()); + return 1; + } + fprintf(stderr, "%s: loaded image '%s' (%d x %d)\n", __func__, params.fname_inp.c_str(), img0.nx, img0.ny); + + // preprocess to f32 + sam_image_f32 img1; + if (!sam_image_preprocess(img0, img1)) { + fprintf(stderr, "%s: failed to preprocess image\n", __func__); + return 1; + } + fprintf(stderr, "%s: preprocessed image (%d x %d)\n", __func__, img1.nx, img1.ny); + + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!sam_model_load(params, model)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + } + + { + static size_t buf_size = 256u*1024*1024; + + struct ggml_init_params ggml_params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + state.ctx = ggml_init(ggml_params); + + state.embd_img = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, + model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans); + + state.low_res_masks = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32, + model.hparams.n_enc_out_chans, model.hparams.n_enc_out_chans, 3); + + state.iou_predictions = ggml_new_tensor_1d(state.ctx, GGML_TYPE_F32, 3); + } + + // Encode image + { + state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead()); + state.allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type()); + + struct ggml_cgraph * gf = sam_encode_image(model, state, img1); + if (!gf) { + fprintf(stderr, "%s: failed to encode image\n", __func__); + return 1; + } + + ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads); + + // print_t_f32("embd_img", state.embd_img); + + ggml_gallocr_free(state.allocr); + state.allocr = NULL; + state.work_buffer.clear(); + } + + // Encode prompt and decode mask + { + state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead()); + state.allocr = ggml_gallocr_new(ggml_backend_cpu_buffer_type()); + + switch (params.prompt.prompt_type) { + case SAM_PROMPT_TYPE_POINT: + fprintf(stderr, "Using point prompt: (%f, %f)\n", params.prompt.pt.x, params.prompt.pt.y); + break; + case SAM_PROMPT_TYPE_BOX: + fprintf(stderr, "Using box prompt: (%f, %f, %f, %f)\n", + params.prompt.box.x1, + params.prompt.box.y1, + params.prompt.box.x2, + params.prompt.box.y2); + break; + } + + struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, params.prompt, params.multimask_output); + if (!gf) { + fprintf(stderr, "%s: failed to build fast graph\n", __func__); + return 1; + } + + ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads); + + //print_t_f32("iou_predictions", state.iou_predictions); + //print_t_f32("low_res_masks", state.low_res_masks); + ggml_gallocr_free(state.allocr); + state.allocr = NULL; + } + + if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state, params.fname_out, params.multimask_output)) { + fprintf(stderr, "%s: failed to write masks\n", __func__); + return 1; + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + fprintf(stderr, "\n\n"); + fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); + fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + ggml_free(model.ctx); + + return 0; +}