From: Yavor Ivanov Date: Tue, 5 Sep 2023 11:40:17 +0000 (+0300) Subject: sam : remove ggml_repeat and use inplace operation (#493) X-Git-Tag: upstream/0.0.1642~1250 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=26719f92533e8285c5ecf7680d77b220da2e4934;p=pkg%2Fggml%2Fsources%2Fggml sam : remove ggml_repeat and use inplace operation (#493) --- diff --git a/examples/sam/README.md b/examples/sam/README.md index d8702f1d..1e807c74 100644 --- a/examples/sam/README.md +++ b/examples/sam/README.md @@ -9,7 +9,7 @@ The example currently supports only the [ViT-B SAM model checkpoint](https://hug ## Next steps - [X] Reduce memory usage by utilizing the new ggml-alloc -- [ ] Remove redundant graph nodes +- [X] Remove redundant graph nodes - [ ] Make inference faster - [X] Fix the difference in output masks compared to the PyTorch implementation - [X] Filter masks based on stability score @@ -28,14 +28,14 @@ cd ggml python3 -m pip install -r requirements.txt # Convert PTH model to ggml -python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1 +python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth . 1 # Build ggml + examples mkdir build && cd build cmake .. && make -j4 # run inference -./bin/sam -t 16 -i ../img.jpg -m ../examples/sam/ggml-model-f16.bin +./bin/sam -t 16 -i ../img.jpg -m examples/sam/ggml-model-f16.bin ``` ## Downloading and converting the model checkpoints @@ -44,17 +44,20 @@ You can download a [model checkpoint](https://github.com/facebookresearch/segmen ``` # Convert PTH model to ggml -python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1 +python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth . 1 ``` -## Example output +## Example output on M2 Ultra ``` -$ ./bin/sam -t 16 -i ../img.jpg -m ../examples/sam/ggml-model-f16.bin -main: seed = 1692347524 -main: loaded image '../img.jpg' (680 x 453) + $ ▶ make -j sam && time ./bin/sam -t 8 -i img.jpg +[ 28%] Built target common +[ 71%] Built target ggml +[100%] Built target sam +main: seed = 1693224265 +main: loaded image 'img.jpg' (680 x 453) sam_image_preprocess: scale = 0.664062 main: preprocessed image (1024 x 1024) -sam_model_load: loading model from '../examples/sam/ggml-model-f16.bin' - please wait ... +sam_model_load: loading model from 'models/sam-vit-b/ggml-model-f16.bin' - please wait ... sam_model_load: n_enc_state = 768 sam_model_load: n_enc_layer = 12 sam_model_load: n_enc_head = 12 @@ -65,11 +68,24 @@ sam_model_load: qntvr = 0 operator(): ggml ctx size = 202.32 MB sam_model_load: ...................................... done sam_model_load: model size = 185.05 MB / num tensors = 304 -point: 624.500000 245.593750 +embd_img +dims: 64 64 256 1 f32 +First & Last 10 elements: +-0.05117 -0.06408 -0.07154 -0.06991 -0.07212 -0.07690 -0.07508 -0.07281 -0.07383 -0.06779 +0.01589 0.01775 0.02250 0.01675 0.01766 0.01661 0.01811 0.02051 0.02103 0.03382 +sum: 12736.272313 +Skipping mask 0 with iou 0.705935 below threshold 0.880000 +Skipping mask 1 with iou 0.762136 below threshold 0.880000 +Mask 2: iou = 0.947081, stability_score = 0.955437, bbox (371, 436), (144, 168) -main: load time = 88.36 ms -main: total time = 5697.57 ms + +main: load time = 51.28 ms +main: total time = 2047.49 ms + +real 0m2.068s +user 0m16.343s +sys 0m0.214s ``` Input point is (414.375, 162.796875) (currently hardcoded) @@ -78,9 +94,9 @@ Input image: ![llamas](https://user-images.githubusercontent.com/8558655/261301565-37b7bf4b-bf91-40cf-8ec1-1532316e1612.jpg) -Output mask: +Output mask (mask_out_2.png in build folder): -![mask_glasses](https://user-images.githubusercontent.com/8558655/261301844-9fc2dbbc-5fd6-42ce-af69-643df9e6fad1.png) +![mask_glasses](https://user-images.githubusercontent.com/8558655/263706800-47eeea30-1457-4c87-938b-8f11536c5aa7.png) ## References diff --git a/examples/sam/convert-pth-to-ggml.py b/examples/sam/convert-pth-to-ggml.py index 5f97f0fb..0de422e5 100644 --- a/examples/sam/convert-pth-to-ggml.py +++ b/examples/sam/convert-pth-to-ggml.py @@ -1,23 +1,21 @@ # Convert a SAM model checkpoint to a ggml compatible file # -import os import sys -import code -import json import torch import struct import numpy as np if len(sys.argv) < 3: - print("Usage: convert-pth-to-ggml.py file-model ftype\n") + print("Usage: convert-pth-to-ggml.py file-model dir-output [ftype]\n") print(" ftype == 0 -> float32") print(" ftype == 1 -> float16") sys.exit(1) # output in the same directory as the model fname_model = sys.argv[1] -fname_out = os.path.dirname(fname_model) + "/ggml-model.bin" +dir_out = sys.argv[2] +fname_out = dir_out + "/ggml-model.bin" # possible data types # ftype == 0 -> float32 @@ -27,8 +25,8 @@ fname_out = os.path.dirname(fname_model) + "/ggml-model.bin" ftype_str = ["f32", "f16"] ftype = 1 -if len(sys.argv) > 2: - ftype = int(sys.argv[2]) +if len(sys.argv) > 3: + ftype = int(sys.argv[3]) if ftype < 0 or ftype > 1: print("Invalid ftype: " + str(ftype)) diff --git a/examples/sam/main.cpp b/examples/sam/main.cpp index 7a140006..a3469021 100644 --- a/examples/sam/main.cpp +++ b/examples/sam/main.cpp @@ -1067,7 +1067,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) { } } - if (n_tensors != ptrdiff_t(model.tensors.size())) { + if (n_tensors != int(model.tensors.size())) { fprintf(stderr, "%s: model file has %d tensors, but %d tensors were expected\n", __func__, n_tensors, (int) model.tensors.size()); return false; } @@ -1201,11 +1201,9 @@ struct ggml_cgraph * sam_encode_image( // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L392 struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx0, enc.proj_w, inp); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - enc.proj_b, - cur), - cur); + cur = ggml_add_inplace(ctx0, + cur, + ggml_repeat(ctx0, enc.proj_b, cur)); // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L394 // keep in F32 @@ -1218,7 +1216,7 @@ struct ggml_cgraph * sam_encode_image( // ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_enc_state, n_img_embd, n_img_embd)); // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L108-L109 - cur = ggml_add(ctx0, enc.pe, cur); + cur = ggml_add_inplace(ctx0, cur, enc.pe); struct ggml_tensor * inpL = cur; @@ -1231,11 +1229,8 @@ struct ggml_cgraph * sam_encode_image( cur = ggml_norm(ctx0, inpL, hparams.eps); // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.norm1_w, cur), - cur), - ggml_repeat(ctx0, layer.norm1_b, cur)); + cur = ggml_mul(ctx0, cur, layer.norm1_w); + cur = ggml_add_inplace(ctx0, cur, layer.norm1_b); } const int64_t w0 = cur->ne[1]; @@ -1253,11 +1248,7 @@ struct ggml_cgraph * sam_encode_image( // self-attention { cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.qkv_b, - cur), - cur); + cur = ggml_add_inplace(ctx0, cur, layer.qkv_b); // split qkv into separate tensors // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L225-L229 @@ -1320,11 +1311,7 @@ struct ggml_cgraph * sam_encode_image( n_enc_state, W, H, B); cur = ggml_mul_mat(ctx0, layer.proj_w, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.proj_b, - cur), - cur); + cur = ggml_add_inplace(ctx0, cur, layer.proj_b); } if (hparams.is_global_attn(il) == false) { @@ -1332,7 +1319,7 @@ struct ggml_cgraph * sam_encode_image( cur = ggml_win_unpart(ctx0, cur, w0, h0, n_window_size); } - cur = ggml_add(ctx0, inpL, cur); + cur = ggml_add_inplace(ctx0, cur, inpL); struct ggml_tensor * inpFF = cur; @@ -1343,33 +1330,20 @@ struct ggml_cgraph * sam_encode_image( cur = ggml_norm(ctx0, inpFF, hparams.eps); // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.norm2_w, cur), - cur), - ggml_repeat(ctx0, layer.norm2_b, cur)); + cur = ggml_mul(ctx0, cur, layer.norm2_w); + cur = ggml_add_inplace(ctx0, cur, layer.norm2_b); } // fully connected - cur = ggml_mul_mat(ctx0, - layer.mlp_lin1_w, - cur); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_lin1_b, cur), - cur); + cur = ggml_mul_mat(ctx0, layer.mlp_lin1_w, cur); + cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin1_b); // GELU activation cur = ggml_gelu(ctx0, cur); // projection - cur = ggml_mul_mat(ctx0, - layer.mlp_lin2_w, - cur); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_lin2_b, cur), - cur); + cur = ggml_mul_mat(ctx0, layer.mlp_lin2_w, cur); + cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin2_b); } inpL = ggml_add(ctx0, cur, inpFF); @@ -1494,8 +1468,7 @@ prompt_encoder_result sam_encode_prompt( prompt_encoder_result res; res.embd_prompt_sparse = embd_prompt_sparse; - res.embd_prompt_dense = embd_prompt_dense; - + res.embd_prompt_dense = embd_prompt_dense; return res; } @@ -1514,19 +1487,13 @@ struct ggml_tensor* sam_decode_mask_transformer_attn( struct ggml_tensor * Vcur = {}; Qcur = ggml_mul_mat(ctx0, attn.q_w, queries); - Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, attn.q_b, Qcur), - Qcur); + Qcur = ggml_add_inplace(ctx0, Qcur, attn.q_b); Kcur = ggml_mul_mat(ctx0, attn.k_w, keys); - Kcur = ggml_add(ctx0, - ggml_repeat(ctx0, attn.k_b, Kcur), - Kcur); + Kcur = ggml_add_inplace(ctx0, Kcur, attn.k_b); Vcur = ggml_mul_mat(ctx0, attn.v_w, values); - Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, attn.v_b, Vcur), - Vcur); + Vcur = ggml_add_inplace(ctx0, Vcur, attn.v_b); struct ggml_tensor * Q = {}; struct ggml_tensor * K = {}; @@ -1557,9 +1524,7 @@ struct ggml_tensor* sam_decode_mask_transformer_attn( KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3)); KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]); KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged); - KQV_merged = ggml_add(ctx0, - ggml_repeat(ctx0, attn.out_b, KQV_merged), - KQV_merged); + KQV_merged = ggml_add_inplace(ctx0, KQV_merged, attn.out_b); return KQV_merged; } @@ -1576,21 +1541,17 @@ struct ggml_tensor * sam_decode_mask_mlp_relu_3( struct ggml_tensor * cur = {}; cur = ggml_mul_mat(ctx0, w_0, in); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, b_0, cur), - cur); - cur = ggml_relu(ctx0, cur); + cur = ggml_add_inplace(ctx0, cur, b_0); + + cur = ggml_relu_inplace(ctx0, cur); cur = ggml_mul_mat(ctx0, w_1, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, b_1, cur), - cur); - cur = ggml_relu(ctx0, cur); + cur = ggml_add_inplace(ctx0, cur, b_1); + + cur = ggml_relu_inplace(ctx0, cur); cur = ggml_mul_mat(ctx0, w_2, cur); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, b_2, cur), - cur); + cur = ggml_add_inplace(ctx0, cur, b_2); return cur; } @@ -1692,15 +1653,13 @@ bool sam_decode_mask( struct ggml_tensor * q_0 = ggml_add(ctx0, queries, tokens); struct ggml_tensor * self_attn = sam_decode_mask_transformer_attn(tfm_layer.self_attn, q_0, q_0, queries, ctx0, model); - queries = ggml_add(ctx0, queries, self_attn); + queries = ggml_add_inplace(ctx0, queries, self_attn); } queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, tfm_layer.norm1_w, queries), - queries), - ggml_repeat(ctx0, tfm_layer.norm1_b, queries)); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, tfm_layer.norm1_w), + tfm_layer.norm1_b); // Cross attention block, tokens attending to image embedding // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L163 @@ -1709,13 +1668,11 @@ bool sam_decode_mask( struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model); - queries = ggml_add(ctx0, queries, cross_attn_token_to_img); - queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, tfm_layer.norm2_w, queries), - queries), - ggml_repeat(ctx0, tfm_layer.norm2_b, queries)); + queries = ggml_add_inplace(ctx0, queries, cross_attn_token_to_img); + queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, tfm_layer.norm2_w), + tfm_layer.norm2_b); // MLP block // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L170 @@ -1723,27 +1680,19 @@ bool sam_decode_mask( tfm_layer.mlp_lin1_w, queries); - mlp_out = ggml_add(ctx0, - ggml_repeat(ctx0, tfm_layer.mlp_lin1_b, mlp_out), - mlp_out); + mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin1_b); // RELU activation - mlp_out = ggml_relu(ctx0, mlp_out); - mlp_out = ggml_mul_mat(ctx0, - tfm_layer.mlp_lin2_w, - mlp_out); + mlp_out = ggml_relu_inplace(ctx0, mlp_out); + mlp_out = ggml_mul_mat(ctx0, tfm_layer.mlp_lin2_w, mlp_out); - mlp_out = ggml_add(ctx0, - ggml_repeat(ctx0, tfm_layer.mlp_lin2_b, mlp_out), - mlp_out); + mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin2_b); - queries = ggml_add(ctx0, queries, mlp_out); - queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, tfm_layer.norm3_w, queries), - queries), - ggml_repeat(ctx0, tfm_layer.norm3_b, queries)); + queries = ggml_add_inplace(ctx0, queries, mlp_out); + queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, tfm_layer.norm3_w), + tfm_layer.norm3_b); // Cross attention block, image embedding attending to tokens // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L175 @@ -1751,13 +1700,11 @@ bool sam_decode_mask( struct ggml_tensor * k_2 = ggml_add(ctx0, keys, pos_src); struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model); - keys = ggml_add(ctx0, keys, cross_attn_img_to_token); - keys = ggml_norm(ctx0, keys, hparams.eps_decoder_transformer); - keys = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, tfm_layer.norm4_w, keys), - keys), - ggml_repeat(ctx0, tfm_layer.norm4_b, keys)); + keys = ggml_add_inplace(ctx0, keys, cross_attn_img_to_token); + keys = ggml_norm_inplace(ctx0, keys, hparams.eps_decoder_transformer); + keys = ggml_add_inplace(ctx0, + ggml_mul(ctx0, keys, tfm_layer.norm4_w), + tfm_layer.norm4_b); } // Apply the final attention layer from the points to the image @@ -1767,13 +1714,11 @@ bool sam_decode_mask( struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model); - queries = ggml_add(ctx0, queries, final_attn_token_to_img); - queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer); - queries = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, dec.transformer_norm_final_w, queries), - queries), - ggml_repeat(ctx0, dec.transformer_norm_final_b, queries)); + queries = ggml_add_inplace(ctx0, queries, final_attn_token_to_img); + queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer); + queries = ggml_add_inplace(ctx0, + ggml_mul(ctx0, queries, dec.transformer_norm_final_w), + dec.transformer_norm_final_b); } @@ -1790,23 +1735,23 @@ bool sam_decode_mask( // ConvTranspose2d keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2); ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed - keys = ggml_add(ctx0, keys, ggml_repeat(ctx0, + keys = ggml_add_inplace(ctx0, keys, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]), keys)); keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps); // GELU activation - keys = ggml_gelu(ctx0, keys); + keys = ggml_gelu_inplace(ctx0, keys); // ConvTranspose2d keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2); ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed - keys = ggml_add(ctx0, ggml_repeat(ctx0, + keys = ggml_add_inplace(ctx0, ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]), keys), keys); // GELU activation - keys = ggml_gelu(ctx0, keys); + keys = ggml_gelu_inplace(ctx0, keys); upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]); upscaled_embedding = ggml_cont(ctx0, ggml_transpose(ctx0, upscaled_embedding)); // TODO: Shouldn't be needed }