]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sam : remove ggml_repeat and use inplace operation (#493)
authorYavor Ivanov <redacted>
Tue, 5 Sep 2023 11:40:17 +0000 (14:40 +0300)
committerGitHub <redacted>
Tue, 5 Sep 2023 11:40:17 +0000 (14:40 +0300)
examples/sam/README.md
examples/sam/convert-pth-to-ggml.py
examples/sam/main.cpp

index d8702f1d557bad661d95dde3cadbc94c102c600d..1e807c7405dd32075ab18c5da1e05fdd02c82bdd 100644 (file)
@@ -9,7 +9,7 @@ The example currently supports only the [ViT-B SAM model checkpoint](https://hug
 ## Next steps
 
 - [X] Reduce memory usage by utilizing the new ggml-alloc
-- [ ] Remove redundant graph nodes
+- [X] Remove redundant graph nodes
 - [ ] Make inference faster
 - [X] Fix the difference in output masks compared to the PyTorch implementation
 - [X] Filter masks based on stability score
@@ -28,14 +28,14 @@ cd ggml
 python3 -m pip install -r requirements.txt
 
 # Convert PTH model to ggml
-python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1
+python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1
 
 # Build ggml + examples
 mkdir build && cd build
 cmake .. && make -j4
 
 # run inference
-./bin/sam -t 16 -i ../img.jpg -m ../examples/sam/ggml-model-f16.bin
+./bin/sam -t 16 -i ../img.jpg -m examples/sam/ggml-model-f16.bin
 ```
 
 ## Downloading and converting the model checkpoints
@@ -44,17 +44,20 @@ You can download a [model checkpoint](https://github.com/facebookresearch/segmen
 
 ```
 # Convert PTH model to ggml
-python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1
+python convert-pth-to-ggml.py examples/sam/sam_vit_b_01ec64.pth 1
 ```
 
-## Example output
+## Example output on M2 Ultra
 ```
-$ ./bin/sam -t 16 -i ../img.jpg -m ../examples/sam/ggml-model-f16.bin
-main: seed = 1692347524
-main: loaded image '../img.jpg' (680 x 453)
+ $ ▶ make -j sam && time ./bin/sam -t 8 -i img.jpg
+[ 28%] Built target common
+[ 71%] Built target ggml
+[100%] Built target sam
+main: seed = 1693224265
+main: loaded image 'img.jpg' (680 x 453)
 sam_image_preprocess: scale = 0.664062
 main: preprocessed image (1024 x 1024)
-sam_model_load: loading model from '../examples/sam/ggml-model-f16.bin' - please wait ...
+sam_model_load: loading model from 'models/sam-vit-b/ggml-model-f16.bin' - please wait ...
 sam_model_load: n_enc_state      = 768
 sam_model_load: n_enc_layer      = 12
 sam_model_load: n_enc_head       = 12
@@ -65,11 +68,24 @@ sam_model_load: qntvr            = 0
 operator(): ggml ctx size = 202.32 MB
 sam_model_load: ...................................... done
 sam_model_load: model size =   185.05 MB / num tensors = 304
-point: 624.500000 245.593750
+embd_img
+dims: 64 64 256 1 f32
+First & Last 10 elements:
+-0.05117 -0.06408 -0.07154 -0.06991 -0.07212 -0.07690 -0.07508 -0.07281 -0.07383 -0.06779
+0.01589 0.01775 0.02250 0.01675 0.01766 0.01661 0.01811 0.02051 0.02103 0.03382
+sum:  12736.272313
 
+Skipping mask 0 with iou 0.705935 below threshold 0.880000
+Skipping mask 1 with iou 0.762136 below threshold 0.880000
+Mask 2: iou = 0.947081, stability_score = 0.955437, bbox (371, 436), (144, 168)
 
-main:     load time =    88.36 ms
-main:    total time =  5697.57 ms
+
+main:     load time =    51.28 ms
+main:    total time =  2047.49 ms
+
+real   0m2.068s
+user   0m16.343s
+sys    0m0.214s
 ```
 
 Input point is (414.375, 162.796875) (currently hardcoded)
@@ -78,9 +94,9 @@ Input image:
 
 ![llamas](https://user-images.githubusercontent.com/8558655/261301565-37b7bf4b-bf91-40cf-8ec1-1532316e1612.jpg)
 
-Output mask:
+Output mask (mask_out_2.png in build folder):
 
-![mask_glasses](https://user-images.githubusercontent.com/8558655/261301844-9fc2dbbc-5fd6-42ce-af69-643df9e6fad1.png)
+![mask_glasses](https://user-images.githubusercontent.com/8558655/263706800-47eeea30-1457-4c87-938b-8f11536c5aa7.png)
 
 ## References
 
index 5f97f0fb43147a3cccb35270637295875922645f..0de422e5517d58de334ec4ecd58aa6528c56fc48 100644 (file)
@@ -1,23 +1,21 @@
 # Convert a SAM model checkpoint to a ggml compatible file
 #
 
-import os
 import sys
-import code
-import json
 import torch
 import struct
 import numpy as np
 
 if len(sys.argv) < 3:
-    print("Usage: convert-pth-to-ggml.py file-model ftype\n")
+    print("Usage: convert-pth-to-ggml.py file-model dir-output [ftype]\n")
     print("  ftype == 0 -> float32")
     print("  ftype == 1 -> float16")
     sys.exit(1)
 
 # output in the same directory as the model
 fname_model = sys.argv[1]
-fname_out   = os.path.dirname(fname_model) + "/ggml-model.bin"
+dir_out     = sys.argv[2]
+fname_out   = dir_out + "/ggml-model.bin"
 
 # possible data types
 #   ftype == 0 -> float32
@@ -27,8 +25,8 @@ fname_out   = os.path.dirname(fname_model) + "/ggml-model.bin"
 ftype_str = ["f32", "f16"]
 
 ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
+if len(sys.argv) > 3:
+    ftype = int(sys.argv[3])
 
 if ftype < 0 or ftype > 1:
     print("Invalid ftype: " + str(ftype))
index 7a140006504df3f928364431730759860ef5c8c5..a3469021a23c3ba89c1017f0a0b257a872b5d2c4 100644 (file)
@@ -1067,7 +1067,7 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
             }
         }
 
-        if (n_tensors != ptrdiff_t(model.tensors.size())) {
+        if (n_tensors != int(model.tensors.size())) {
             fprintf(stderr, "%s: model file has %d tensors, but %d tensors were expected\n", __func__, n_tensors, (int) model.tensors.size());
             return false;
         }
@@ -1201,11 +1201,9 @@ struct ggml_cgraph  * sam_encode_image(
 
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L392
     struct ggml_tensor * cur = ggml_conv_2d_sk_p0(ctx0, enc.proj_w, inp);
-    cur = ggml_add(ctx0,
-            ggml_repeat(ctx0,
-                enc.proj_b,
-                cur),
-            cur);
+    cur = ggml_add_inplace(ctx0,
+            cur,
+            ggml_repeat(ctx0, enc.proj_b, cur));
 
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L394
     // keep in F32
@@ -1218,7 +1216,7 @@ struct ggml_cgraph  * sam_encode_image(
     //        ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_enc_state, n_img_embd, n_img_embd));
 
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L108-L109
-    cur = ggml_add(ctx0, enc.pe, cur);
+    cur = ggml_add_inplace(ctx0, cur, enc.pe);
 
     struct ggml_tensor * inpL = cur;
 
@@ -1231,11 +1229,8 @@ struct ggml_cgraph  * sam_encode_image(
             cur = ggml_norm(ctx0, inpL, hparams.eps);
 
             // cur = ln_0_w*cur + ln_0_b
-            cur = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, layer.norm1_w, cur),
-                        cur),
-                    ggml_repeat(ctx0, layer.norm1_b, cur));
+            cur = ggml_mul(ctx0, cur, layer.norm1_w);
+            cur = ggml_add_inplace(ctx0, cur, layer.norm1_b);
         }
 
         const int64_t w0 = cur->ne[1];
@@ -1253,11 +1248,7 @@ struct ggml_cgraph  * sam_encode_image(
         // self-attention
         {
             cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.qkv_b,
-                        cur),
-                    cur);
+            cur = ggml_add_inplace(ctx0, cur, layer.qkv_b);
 
             // split qkv into separate tensors
             // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L225-L229
@@ -1320,11 +1311,7 @@ struct ggml_cgraph  * sam_encode_image(
                         n_enc_state, W, H, B);
 
             cur = ggml_mul_mat(ctx0, layer.proj_w, cur);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.proj_b,
-                        cur),
-                    cur);
+            cur = ggml_add_inplace(ctx0, cur, layer.proj_b);
         }
 
         if (hparams.is_global_attn(il) == false) {
@@ -1332,7 +1319,7 @@ struct ggml_cgraph  * sam_encode_image(
             cur = ggml_win_unpart(ctx0, cur, w0, h0, n_window_size);
         }
 
-        cur = ggml_add(ctx0, inpL, cur);
+        cur = ggml_add_inplace(ctx0, cur, inpL);
 
         struct ggml_tensor * inpFF = cur;
 
@@ -1343,33 +1330,20 @@ struct ggml_cgraph  * sam_encode_image(
                 cur = ggml_norm(ctx0, inpFF, hparams.eps);
 
                 // cur = mlp_ln_w*cur + mlp_ln_b
-                cur = ggml_add(ctx0,
-                        ggml_mul(ctx0,
-                            ggml_repeat(ctx0, layer.norm2_w, cur),
-                            cur),
-                        ggml_repeat(ctx0, layer.norm2_b, cur));
+                cur = ggml_mul(ctx0, cur, layer.norm2_w);
+                cur = ggml_add_inplace(ctx0, cur, layer.norm2_b);
             }
 
             // fully connected
-            cur = ggml_mul_mat(ctx0,
-                    layer.mlp_lin1_w,
-                    cur);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_lin1_b, cur),
-                    cur);
+            cur = ggml_mul_mat(ctx0, layer.mlp_lin1_w, cur);
+            cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin1_b);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
 
             // projection
-            cur = ggml_mul_mat(ctx0,
-                    layer.mlp_lin2_w,
-                    cur);
-
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0, layer.mlp_lin2_b, cur),
-                    cur);
+            cur = ggml_mul_mat(ctx0, layer.mlp_lin2_w, cur);
+            cur = ggml_add_inplace(ctx0, cur, layer.mlp_lin2_b);
         }
 
         inpL = ggml_add(ctx0, cur, inpFF);
@@ -1494,8 +1468,7 @@ prompt_encoder_result sam_encode_prompt(
 
     prompt_encoder_result res;
     res.embd_prompt_sparse = embd_prompt_sparse;
-    res.embd_prompt_dense = embd_prompt_dense;
-
+    res.embd_prompt_dense  = embd_prompt_dense;
     return res;
 }
 
@@ -1514,19 +1487,13 @@ struct ggml_tensor* sam_decode_mask_transformer_attn(
     struct ggml_tensor * Vcur = {};
 
     Qcur = ggml_mul_mat(ctx0, attn.q_w, queries);
-    Qcur = ggml_add(ctx0,
-            ggml_repeat(ctx0, attn.q_b, Qcur),
-            Qcur);
+    Qcur = ggml_add_inplace(ctx0, Qcur, attn.q_b);
 
     Kcur = ggml_mul_mat(ctx0, attn.k_w, keys);
-    Kcur = ggml_add(ctx0,
-            ggml_repeat(ctx0, attn.k_b, Kcur),
-            Kcur);
+    Kcur = ggml_add_inplace(ctx0, Kcur, attn.k_b);
 
     Vcur = ggml_mul_mat(ctx0, attn.v_w, values);
-    Vcur = ggml_add(ctx0,
-            ggml_repeat(ctx0, attn.v_b, Vcur),
-            Vcur);
+    Vcur = ggml_add_inplace(ctx0, Vcur, attn.v_b);
 
     struct ggml_tensor * Q = {};
     struct ggml_tensor * K = {};
@@ -1557,9 +1524,7 @@ struct ggml_tensor* sam_decode_mask_transformer_attn(
     KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3));
     KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]);
     KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged);
-    KQV_merged = ggml_add(ctx0,
-                ggml_repeat(ctx0, attn.out_b, KQV_merged),
-                KQV_merged);
+    KQV_merged = ggml_add_inplace(ctx0, KQV_merged, attn.out_b);
 
     return KQV_merged;
 }
@@ -1576,21 +1541,17 @@ struct ggml_tensor * sam_decode_mask_mlp_relu_3(
 
     struct ggml_tensor * cur = {};
     cur = ggml_mul_mat(ctx0, w_0, in);
-    cur = ggml_add(ctx0,
-            ggml_repeat(ctx0, b_0, cur),
-            cur);
-    cur = ggml_relu(ctx0, cur);
+    cur = ggml_add_inplace(ctx0, cur, b_0);
+
+    cur = ggml_relu_inplace(ctx0, cur);
 
     cur = ggml_mul_mat(ctx0, w_1, cur);
-    cur = ggml_add(ctx0,
-            ggml_repeat(ctx0, b_1, cur),
-            cur);
-    cur = ggml_relu(ctx0, cur);
+    cur = ggml_add_inplace(ctx0, cur, b_1);
+
+    cur = ggml_relu_inplace(ctx0, cur);
 
     cur = ggml_mul_mat(ctx0, w_2, cur);
-    cur = ggml_add(ctx0,
-            ggml_repeat(ctx0, b_2, cur),
-            cur);
+    cur = ggml_add_inplace(ctx0, cur, b_2);
 
     return cur;
 }
@@ -1692,15 +1653,13 @@ bool sam_decode_mask(
                 struct ggml_tensor * q_0 = ggml_add(ctx0, queries, tokens);
 
                 struct ggml_tensor * self_attn = sam_decode_mask_transformer_attn(tfm_layer.self_attn, q_0, q_0, queries, ctx0, model);
-                queries = ggml_add(ctx0, queries, self_attn);
+                queries = ggml_add_inplace(ctx0, queries, self_attn);
             }
 
             queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
-            queries = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, tfm_layer.norm1_w, queries),
-                        queries),
-                    ggml_repeat(ctx0, tfm_layer.norm1_b, queries));
+            queries = ggml_add_inplace(ctx0,
+                    ggml_mul(ctx0, queries, tfm_layer.norm1_w),
+                    tfm_layer.norm1_b);
 
             // Cross attention block, tokens attending to image embedding
             // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L163
@@ -1709,13 +1668,11 @@ bool sam_decode_mask(
 
             struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model);
 
-            queries = ggml_add(ctx0, queries, cross_attn_token_to_img);
-            queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
-            queries = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, tfm_layer.norm2_w, queries),
-                        queries),
-                    ggml_repeat(ctx0, tfm_layer.norm2_b, queries));
+            queries = ggml_add_inplace(ctx0, queries, cross_attn_token_to_img);
+            queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);
+            queries = ggml_add_inplace(ctx0,
+                    ggml_mul(ctx0, queries, tfm_layer.norm2_w),
+                    tfm_layer.norm2_b);
 
             // MLP block
             // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L170
@@ -1723,27 +1680,19 @@ bool sam_decode_mask(
                 tfm_layer.mlp_lin1_w,
                 queries);
 
-            mlp_out = ggml_add(ctx0,
-                    ggml_repeat(ctx0, tfm_layer.mlp_lin1_b, mlp_out),
-                    mlp_out);
+            mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin1_b);
 
             // RELU activation
-            mlp_out = ggml_relu(ctx0, mlp_out);
-            mlp_out = ggml_mul_mat(ctx0,
-                    tfm_layer.mlp_lin2_w,
-                    mlp_out);
+            mlp_out = ggml_relu_inplace(ctx0, mlp_out);
+            mlp_out = ggml_mul_mat(ctx0, tfm_layer.mlp_lin2_w, mlp_out);
 
-            mlp_out = ggml_add(ctx0,
-                    ggml_repeat(ctx0, tfm_layer.mlp_lin2_b, mlp_out),
-                    mlp_out);
+            mlp_out = ggml_add_inplace(ctx0, mlp_out, tfm_layer.mlp_lin2_b);
 
-            queries = ggml_add(ctx0, queries, mlp_out);
-            queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
-            queries = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, tfm_layer.norm3_w, queries),
-                        queries),
-                    ggml_repeat(ctx0, tfm_layer.norm3_b, queries));
+            queries = ggml_add_inplace(ctx0, queries, mlp_out);
+            queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);
+            queries = ggml_add_inplace(ctx0,
+                    ggml_mul(ctx0, queries, tfm_layer.norm3_w),
+                    tfm_layer.norm3_b);
 
             // Cross attention block, image embedding attending to tokens
             // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/transformer.py#L175
@@ -1751,13 +1700,11 @@ bool sam_decode_mask(
             struct ggml_tensor * k_2 = ggml_add(ctx0, keys, pos_src);
 
             struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model);
-            keys = ggml_add(ctx0, keys, cross_attn_img_to_token);
-            keys = ggml_norm(ctx0, keys, hparams.eps_decoder_transformer);
-            keys = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, tfm_layer.norm4_w, keys),
-                        keys),
-                    ggml_repeat(ctx0, tfm_layer.norm4_b, keys));
+            keys = ggml_add_inplace(ctx0, keys, cross_attn_img_to_token);
+            keys = ggml_norm_inplace(ctx0, keys, hparams.eps_decoder_transformer);
+            keys = ggml_add_inplace(ctx0,
+                    ggml_mul(ctx0, keys, tfm_layer.norm4_w),
+                    tfm_layer.norm4_b);
         }
 
         // Apply the final attention layer from the points to the image
@@ -1767,13 +1714,11 @@ bool sam_decode_mask(
 
         struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model);
 
-        queries = ggml_add(ctx0, queries, final_attn_token_to_img);
-        queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
-        queries = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, dec.transformer_norm_final_w, queries),
-                    queries),
-                ggml_repeat(ctx0, dec.transformer_norm_final_b, queries));
+        queries = ggml_add_inplace(ctx0, queries, final_attn_token_to_img);
+        queries = ggml_norm_inplace(ctx0, queries, hparams.eps_decoder_transformer);
+        queries = ggml_add_inplace(ctx0,
+                ggml_mul(ctx0, queries, dec.transformer_norm_final_w),
+                dec.transformer_norm_final_b);
     }
 
 
@@ -1790,23 +1735,23 @@ bool sam_decode_mask(
         // ConvTranspose2d
         keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
         ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed
-        keys = ggml_add(ctx0, keys, ggml_repeat(ctx0,
+        keys = ggml_add_inplace(ctx0, keys, ggml_repeat(ctx0,
                                      ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]),
                                      keys));
 
         keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps);
 
         // GELU activation
-        keys = ggml_gelu(ctx0, keys);
+        keys = ggml_gelu_inplace(ctx0, keys);
 
         // ConvTranspose2d
         keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2);
         ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed
-        keys = ggml_add(ctx0, ggml_repeat(ctx0,
+        keys = ggml_add_inplace(ctx0, ggml_repeat(ctx0,
                                 ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]),
                                 keys), keys);
         // GELU activation
-        keys = ggml_gelu(ctx0, keys);
+        keys = ggml_gelu_inplace(ctx0, keys);
         upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]);
         upscaled_embedding = ggml_cont(ctx0, ggml_transpose(ctx0, upscaled_embedding)); // TODO: Shouldn't be needed
     }