]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sam : use ggml-alloc (#490)
authorYavor Ivanov <redacted>
Mon, 28 Aug 2023 12:40:23 +0000 (15:40 +0300)
committerGitHub <redacted>
Mon, 28 Aug 2023 12:40:23 +0000 (15:40 +0300)
examples/sam/README.md
examples/sam/main.cpp
src/ggml-alloc.c
src/ggml.c

index fa4b993f6de5d942e3a7af4278ad410e3d11457e..d8702f1d557bad661d95dde3cadbc94c102c600d 100644 (file)
@@ -8,15 +8,15 @@ The example currently supports only the [ViT-B SAM model checkpoint](https://hug
 
 ## Next steps
 
-- [ ] Reduce memory usage by utilizing the new ggml-alloc
+- [X] Reduce memory usage by utilizing the new ggml-alloc
 - [ ] Remove redundant graph nodes
 - [ ] Make inference faster
-- [ ] Fix the difference in output masks compared to the PyTorch implementation
+- [X] Fix the difference in output masks compared to the PyTorch implementation
 - [X] Filter masks based on stability score
 - [ ] Add support for user input
 - [ ] Support F16 for heavy F32 ops
 - [ ] Test quantization
-- [ ] Support bigger model checkpoints
+- [X] Support bigger model checkpoints
 - [ ] GPU support
 
 ## Quick start
index f57156911d957261f7d199d0c01c05185c48c08e..e8bfbb3de8545d15300dc816e8106fcc7d95a3a8 100644 (file)
@@ -1,9 +1,7 @@
 #define _USE_MATH_DEFINES // for M_PI
 
 #include "ggml.h"
-
-#include "common.h"
-
+#include "ggml-alloc.h"
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
 #define STB_IMAGE_WRITE_IMPLEMENTATION
 #include <map>
 #include <string>
 #include <vector>
-
-// void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
-//     printf("%s\n", title);
-//     float * data = (float *)t->data;
-//     printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
-//     printf("First & Last %d elements:\n", n);
-//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
-//         printf("%.5f ", data[i]);
-//         if (i != 0 && i % t->ne[0] == 0) {
-//             printf("\n");
-//         }
-//     }
-//     printf("\n");
-//     for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
-//         printf("%.5f ", data[ggml_nelements(t) - n + i]);
-//         if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
-//             printf("\n");
-//         }
-//     }
-//     printf("\n");
-//     double sum = 0.0;
-//     for (int i = 0; i < ggml_nelements(t); i++) {
-//         sum += data[i];
-//     }
-//     printf("sum:  %f\n\n", sum);
-// }
+#include <thread>
 
 // default hparams (ViT-B SAM)
 struct sam_hparams {
@@ -252,20 +225,37 @@ struct sam_decoder_mask {
     struct ggml_tensor * mask_tokens_w;
 };
 
+
 struct sam_state {
     struct ggml_tensor * embd_img;
-    struct ggml_tensor * embd_prompt_sparse;
-    struct ggml_tensor * embd_prompt_dense;
-    struct ggml_tensor * pe_img_dense;
 
     struct ggml_tensor * low_res_masks;
     struct ggml_tensor * iou_predictions;
 
+    //struct ggml_tensor * tmp_save = {};
+
     struct ggml_context * ctx;
 
-    std::vector<uint8_t> buf;
+    // buffer for `ggml_graph_plan.work_data`
+    std::vector<uint8_t> work_buffer;
+    // buffers to evaluate the model
+    std::vector<uint8_t> buf_alloc_img_enc;
+    std::vector<uint8_t> buf_compute_img_enc;
+
+    std::vector<uint8_t> buf_alloc_fast;
+    std::vector<uint8_t> buf_compute_fast;
+
+    struct ggml_allocr  * allocr = {};
 };
 
+// void save_tensor(sam_state& state, struct ggml_tensor * t, struct ggml_cgraph * gf) {
+//     if (!state.tmp_save) {
+//         state.tmp_save = ggml_new_tensor(state.ctx, t->type, t->n_dims, t->ne);
+//     }
+//     struct ggml_tensor * tmp0 = ggml_cpy(state.ctx, t, state.tmp_save);
+//     ggml_build_forward_expand(gf, tmp0);
+// }
+
 struct sam_model {
     sam_hparams hparams;
 
@@ -300,7 +290,51 @@ struct sam_image_f32 {
     std::vector<float> data;
 };
 
-void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
+void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
+    printf("%s\n", title);
+    float * data = (float *)t->data;
+    printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
+    printf("First & Last %d elements:\n", n);
+    for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+        printf("%.5f ", data[i]);
+        if (i != 0 && i % t->ne[0] == 0) {
+            printf("\n");
+        }
+    }
+    printf("\n");
+    for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+        printf("%.5f ", data[ggml_nelements(t) - n + i]);
+        if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
+            printf("\n");
+        }
+    }
+    printf("\n");
+    double sum = 0.0;
+    for (int i = 0; i < ggml_nelements(t); i++) {
+        sum += data[i];
+    }
+    printf("sum:  %f\n\n", sum);
+}
+
+static void ggml_disconnect_node_from_graph(ggml_tensor * t) {
+    t->op = GGML_OP_NONE;
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        t->src[i] = NULL;
+    }
+}
+
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
+static void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
     GGML_ASSERT(userdata == NULL);
     GGML_ASSERT(ggml_are_same_shape(dst, src));
     GGML_ASSERT(ggml_is_contiguous(dst));
@@ -319,7 +353,7 @@ void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int
     }
 }
 
-void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
+static void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
     GGML_ASSERT(userdata == NULL);
     GGML_ASSERT(ggml_are_same_shape(dst, src));
     GGML_ASSERT(ggml_is_contiguous(dst));
@@ -1043,31 +1077,21 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
     return true;
 }
 
-bool sam_fill_dense_pe(
-            const sam_model & model,
-                  sam_state & state,
-                        int   n_threads) {
+struct ggml_tensor * sam_fill_dense_pe(
+            const sam_model   & model,
+          struct ggml_context * ctx0,
+          struct ggml_cgraph  * gf,
+                  sam_state   & state) {
     const auto & hparams = model.hparams;
     const auto & enc     = model.enc_prompt;
 
     const int32_t n_img_embd = hparams.n_img_embd();
     const float n_img_embd_inv = 1.0f / n_img_embd;
 
-    static size_t buf_size = 256u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
-
     struct ggml_tensor * xy_embed_stacked = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 2, n_img_embd, n_img_embd);
+    ggml_allocr_alloc(state.allocr, xy_embed_stacked);
 
-    {
+    if (!ggml_allocr_is_measure(state.allocr)) {
         float * data = (float *) ggml_get_data(xy_embed_stacked);
         for (int i = 0; i < n_img_embd; ++i) {
             const int row = 2*i*n_img_embd;
@@ -1092,21 +1116,14 @@ bool sam_fill_dense_pe(
 
         cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);
 
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0)));
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1])));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], 0)));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_3d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], t_sin->ne[2], cur->nb[1], cur->nb[2], t_sin->nb[1])));
     }
 
-    cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
-
-    // TODO: avoid copy
-    cur = ggml_cpy(ctx0, cur, state.pe_img_dense);
+    struct ggml_tensor * pe_img_dense = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));
+    ggml_build_forward_expand(gf, pe_img_dense);
 
-    // run the computation
-    ggml_set_name(cur, "check");
-    ggml_build_forward_expand(&gf, cur);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
-    return true;
+    return pe_img_dense;
 }
 
 struct ggml_tensor* sam_layer_norm_2d(
@@ -1132,11 +1149,10 @@ struct ggml_tensor* sam_layer_norm_2d(
     return layer;
 }
 
-bool sam_encode_image(
+struct ggml_cgraph  * sam_encode_image(
             const sam_model & model,
                   sam_state & state,
-        const sam_image_f32 & img,
-                        int   n_threads) {
+        const sam_image_f32 & img) {
 
     const auto & hparams = model.hparams;
     const auto & enc     = model.enc_img;
@@ -1146,32 +1162,21 @@ bool sam_encode_image(
     const int32_t n_enc_head      = hparams.n_enc_head;
     const int32_t n_enc_head_dim  = hparams.n_enc_head_dim();
     const int32_t n_enc_out_chans = hparams.n_enc_out_chans;
-
     const int32_t n_img_size    = hparams.n_img_size();
     const int32_t n_window_size = hparams.n_window_size();
 
-    static size_t buf_size = 256u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    // use 2 scratch buffers
-    // TODO: very hacky solution - reimplement in a more elegant way
-    static size_t scr0_size = 2048u*1024*1024;
-    static void * scr0 = malloc(scr0_size);
-
-    static size_t scr1_size = 512u*1024*1024;
-    static void * scr1 = malloc(scr1_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
+    struct ggml_init_params ggml_params = {
+        /*.mem_size   =*/ state.buf_compute_img_enc.size(),
+        /*.mem_buffer =*/ state.buf_compute_img_enc.data(),
+        /*.no_alloc   =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements
     };
 
-    struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
+    struct ggml_context * ctx0   = ggml_init(ggml_params);
+    struct ggml_cgraph  * gf     = ggml_new_graph(ctx0);
 
     struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3, 1);
-    {
+    ggml_allocr_alloc(state.allocr, inp);
+    if (!ggml_allocr_is_measure(state.allocr)) {
         float * data = (float *) ggml_get_data(inp);
 
         const int nx = img.nx;
@@ -1215,8 +1220,6 @@ bool sam_encode_image(
     for (int il = 0; il < n_enc_layer; ++il) {
         const auto & layer = enc.layers[il];
 
-        ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
-
         // norm
         // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168
         {
@@ -1242,8 +1245,6 @@ bool sam_encode_image(
         const int64_t W = cur->ne[1];
         const int64_t H = cur->ne[2];
 
-        ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
-
         // self-attention
         {
             cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
@@ -1279,8 +1280,6 @@ bool sam_encode_image(
             V = ggml_cont      (ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // transposed
             V = ggml_reshape_3d(ctx0, V,   W*H, n_enc_head_dim, B*n_enc_head);
 
-            ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
-
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
             struct ggml_tensor * KQ_scaled =
@@ -1330,8 +1329,6 @@ bool sam_encode_image(
 
         cur = ggml_add(ctx0, inpL, cur);
 
-        ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
-
         struct ggml_tensor * inpFF = cur;
 
         // feed-forward network
@@ -1373,8 +1370,6 @@ bool sam_encode_image(
         inpL = ggml_add(ctx0, cur, inpFF);
     }
 
-    ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
-
     cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
 
     cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);
@@ -1385,23 +1380,24 @@ bool sam_encode_image(
 
     cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps);
 
-    // TODO: avoid copy
     cur = ggml_cpy(ctx0, cur, state.embd_img);
 
-    ggml_set_name(cur, "check");
-
-    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
-
-    // run the computation
-    ggml_build_forward_expand(&gf, cur);
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, cur);
+    ggml_disconnect_node_from_graph(state.embd_img);
 
     //ggml_graph_print(&gf);
 
     ggml_free(ctx0);
-    return true;
+
+    return gf;
 }
 
+
+struct prompt_encoder_result {
+    struct ggml_tensor * embd_prompt_sparse = {};
+    struct ggml_tensor * embd_prompt_dense = {};
+};
+
 // encode a prompt
 //
 // - points
@@ -1410,29 +1406,18 @@ bool sam_encode_image(
 //
 // TODO: currently just encode a single point for simplicity
 //
-bool sam_encode_prompt(
+prompt_encoder_result sam_encode_prompt(
         const sam_model     & model,
+        struct ggml_context * ctx0,
+        struct ggml_cgraph  * gf,
                   sam_state & state,
                         int   nx,
                         int   ny,
-                  sam_point   point,
-                        int   n_threads) {
+                  sam_point   point) {
 
     const auto & hparams = model.hparams;
     const auto & enc = model.enc_prompt;
 
-    static size_t buf_size = 256u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
-
     // transform points
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py#L276
     {
@@ -1447,12 +1432,11 @@ bool sam_encode_prompt(
         point.y = point.y*(float(ny_new)/ny) + 0.5f;
     }
 
-    printf("point: %f %f\n", point.x, point.y);
-
     struct ggml_tensor * inp = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, 2);
 
-    // set the input by converting the [0, 1] coordinates to [-1, 1]
-    {
+    ggml_allocr_alloc(state.allocr, inp);
+    if (!ggml_allocr_is_measure(state.allocr)) {
+        // set the input by converting the [0, 1] coordinates to [-1, 1]
         float * data = (float *) inp->data;
 
         data[0] = 2.0f*(point.x / hparams.n_img_size()) - 1.0f;
@@ -1468,7 +1452,6 @@ bool sam_encode_prompt(
 
     cur = ggml_scale(ctx0, cur, ggml_new_f32(ctx0, 2.0f*M_PI));
 
-
     // concat
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
     {
@@ -1477,44 +1460,38 @@ bool sam_encode_prompt(
 
         cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]);
 
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0)));
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1])));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_sin, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], 0)));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, t_cos, ggml_view_2d(ctx0, cur, t_sin->ne[0], t_sin->ne[1], cur->nb[1], t_sin->nb[1])));
 
         // overwrite label == -1 with not_a_point_embed.weight
         // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L86
         // TODO: extend for multiple points
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], cur->nb[1])));
-
-        // add point_embeddings[1] to label == 1
-        // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90
-        ggml_build_forward_expand(&gf, ggml_add_inplace(ctx0, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0), enc.pt_embd[1]));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, enc.not_a_pt_embd_w, ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], cur->nb[1])));
     }
 
-    // TODO: avoid copy
-    cur = ggml_cpy(ctx0, cur, state.embd_prompt_sparse);
+    // add point_embeddings[1] to label == 1
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L90
+    struct ggml_tensor * v = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], 0);
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_add_inplace(ctx0, v, enc.pt_embd[1]), v));
 
-    ggml_build_forward_expand(&gf, cur);
+    struct ggml_tensor * embd_prompt_sparse = cur;
+    ggml_build_forward_expand(gf, embd_prompt_sparse);
 
-    cur = ggml_repeat(ctx0,
+    struct ggml_tensor * embd_prompt_dense = ggml_repeat(ctx0,
             ggml_cont(ctx0,
                 ggml_view_3d(ctx0, enc.no_mask_embd_w,
                     1, 1, enc.no_mask_embd_w->ne[0], enc.no_mask_embd_w->nb[0], enc.no_mask_embd_w->nb[0], 0)),
-            state.embd_prompt_dense);
-
-    // TODO: avoid copy
-    cur = ggml_cpy(ctx0, cur, state.embd_prompt_dense);
-
-    ggml_build_forward_expand(&gf, cur);
+            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hparams.n_img_embd(), hparams.n_img_embd(), hparams.n_enc_out_chans));
 
-    ggml_set_name(cur, "check");
-
-    // run the computation
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    ggml_build_forward_expand(gf, embd_prompt_dense);
 
     //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
 
-    ggml_free(ctx0);
-    return true;
+    prompt_encoder_result res;
+    res.embd_prompt_sparse = embd_prompt_sparse;
+    res.embd_prompt_dense = embd_prompt_dense;
+
+    return res;
 }
 
 struct ggml_tensor* sam_decode_mask_transformer_attn(
@@ -1546,9 +1523,6 @@ struct ggml_tensor* sam_decode_mask_transformer_attn(
             ggml_repeat(ctx0, attn.v_b, Vcur),
             Vcur);
 
-    // TODO: use stratch memory
-    // ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
-
     struct ggml_tensor * Q = {};
     struct ggml_tensor * K = {};
     struct ggml_tensor * V = {};
@@ -1617,46 +1591,33 @@ struct ggml_tensor * sam_decode_mask_mlp_relu_3(
 }
 
 bool sam_decode_mask(
-        const sam_model     & model,
-                  sam_state & state,
-                        int   n_threads) {
+                    const sam_model & model,
+        const prompt_encoder_result & prompt,
+                 struct ggml_tensor * pe_img,
+                struct ggml_context * ctx0,
+                struct ggml_cgraph  * gf,
+                          sam_state & state) {
 
     const auto & hparams = model.hparams;
     const auto & dec = model.dec;
     const int n_img_embd = hparams.n_img_embd();
 
-    static size_t buf_size = 384u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = {};
-
-    // print_t_f32("embd_img", state.embd_img);
-    // print_t_f32("embd_prompt_dense", state.embd_prompt_dense);
-    // print_t_f32("embd_prompt_sparse", state.embd_prompt_sparse);
-    // print_t_f32("pe_img_dense", state.pe_img_dense);
-
     struct ggml_tensor * tokens = {};
     {
         // Concatenate output tokens
         // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L120
-        const auto& sparse = state.embd_prompt_sparse;
+        const auto& sparse = prompt.embd_prompt_sparse;
 
         tokens = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, dec.iou_token_w->ne[0], dec.iou_token_w->ne[1] + dec.mask_tokens_w->ne[1] + sparse->ne[1], sparse->ne[2]);
 
         const size_t offsets[3] = { 0, dec.iou_token_w->ne[1]*tokens->nb[1], dec.iou_token_w->ne[1]*tokens->nb[1] + dec.mask_tokens_w->ne[1]*tokens->nb[1] };
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, dec.iou_token_w,   ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.iou_token_w->ne[1],   tokens->nb[1], offsets[0])));
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1])));
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, sparse,            ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1],            tokens->nb[1], offsets[2])));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.iou_token_w,   ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.iou_token_w->ne[1],   tokens->nb[1], offsets[0])));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, dec.mask_tokens_w, ggml_view_2d(ctx0, tokens, tokens->ne[0], dec.mask_tokens_w->ne[1], tokens->nb[1], offsets[1])));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, sparse,            ggml_view_2d(ctx0, tokens, tokens->ne[0], sparse->ne[1],            tokens->nb[1], offsets[2])));
         // TODO: Sparse prompt embeddings can have more than one point
     }
 
+
     struct ggml_tensor * src = {};
     struct ggml_tensor * pos_src = {};
     int srcNE[4] = { 0, 0, 0, 0 };
@@ -1664,11 +1625,12 @@ bool sam_decode_mask(
         // Expand per-image data in the batch direction to be per-mask
         // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L125
         src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.embd_img->ne[0], state.embd_img->ne[1], state.embd_img->ne[2], tokens->ne[2]);
+
         src = ggml_add(ctx0,
             ggml_repeat(ctx0,
                 state.embd_img,
                 src),
-            state.embd_prompt_dense);
+            prompt.embd_prompt_dense);
 
         srcNE[0] = src->ne[0];
         srcNE[1] = src->ne[1];
@@ -1687,11 +1649,10 @@ bool sam_decode_mask(
                 src->nb[3],
                 0),
             1, 0, 2, 3));
-        ggml_build_forward_expand(&gf, src);
 
-        pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, state.pe_img_dense->ne[0], state.pe_img_dense->ne[1], state.pe_img_dense->ne[2], tokens->ne[2]);
+        pos_src = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, pe_img->ne[0], pe_img->ne[1], pe_img->ne[2], tokens->ne[2]);
         pos_src = ggml_repeat(ctx0,
-            state.pe_img_dense,
+            pe_img,
             pos_src);
 
         // flatten & permute
@@ -1706,7 +1667,6 @@ bool sam_decode_mask(
                 pos_src->nb[3],
                 0),
             1, 0, 2, 3));
-        ggml_build_forward_expand(&gf, pos_src);
     }
 
     struct ggml_tensor * queries = tokens;
@@ -1811,6 +1771,7 @@ bool sam_decode_mask(
                 ggml_repeat(ctx0, dec.transformer_norm_final_b, queries));
     }
 
+
     struct ggml_tensor * iou_pred = ggml_view_2d(ctx0, queries, queries->ne[0], queries->ne[2], queries->nb[2], 0);
     const int num_mask_tokens = 4; // num_multimask_outputs + 1
     struct ggml_tensor * mask_tokens_out = ggml_view_3d(ctx0, queries, queries->ne[0], num_mask_tokens, queries->ne[2], queries->nb[1], num_mask_tokens*queries->nb[1], queries->nb[1]);
@@ -1819,11 +1780,11 @@ bool sam_decode_mask(
     // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L136
     keys = ggml_cont(ctx0, ggml_transpose(ctx0, keys));
     keys = ggml_view_4d(ctx0, keys, srcNE[0], srcNE[1], srcNE[2], srcNE[3], srcNE[0]*keys->nb[0], keys->nb[1], keys->nb[2], 0);
-
     struct ggml_tensor * upscaled_embedding = {};
     {
         // ConvTranspose2d
         keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
+        ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed
         keys = ggml_add(ctx0, keys, ggml_repeat(ctx0,
                                      ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]),
                                      keys));
@@ -1835,6 +1796,7 @@ bool sam_decode_mask(
 
         // ConvTranspose2d
         keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2);
+        ggml_allocr_alloc(state.allocr, keys); // TODO: This alloc shouldn't be needed
         keys = ggml_add(ctx0, ggml_repeat(ctx0,
                                 ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]),
                                 keys), keys);
@@ -1845,11 +1807,12 @@ bool sam_decode_mask(
     }
 
     struct ggml_tensor * hyper_in = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_img_embd/2, num_mask_tokens, mask_tokens_out->ne[2]);
+
     for (int i = 0; i < num_mask_tokens; ++i) {
         const auto& mlp = dec.output_hypernet_mlps[i];
         struct ggml_tensor * in = ggml_view_2d(ctx0, mask_tokens_out, mask_tokens_out->ne[0], mask_tokens_out->ne[2], mask_tokens_out->nb[1], i*mask_tokens_out->nb[1]);
         struct ggml_tensor * out = sam_decode_mask_mlp_relu_3(in, mlp.w_0, mlp.b_0, mlp.w_1, mlp.b_1, mlp.w_2, mlp.b_2, ctx0);
-        ggml_build_forward_expand(&gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1])));
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, ggml_view_2d(ctx0, hyper_in, hyper_in->ne[0], hyper_in->ne[2], hyper_in->nb[1], i*hyper_in->nb[1])));
     }
 
     struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding);
@@ -1858,39 +1821,21 @@ bool sam_decode_mask(
 
     // Generate mask quality predictions
     // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146
-    state.iou_predictions = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0);
+    iou_pred = sam_decode_mask_mlp_relu_3(iou_pred, dec.iou_prediction_head_0_w, dec.iou_prediction_head_0_b, dec.iou_prediction_head_1_w, dec.iou_prediction_head_1_b, dec.iou_prediction_head_2_w, dec.iou_prediction_head_2_b, ctx0);
 
     // Select the correct mask or masks for output
     // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L101
-    state.iou_predictions = ggml_view_1d(ctx0, state.iou_predictions, state.iou_predictions->ne[0] - 1, state.iou_predictions->nb[0]);
-
-    state.low_res_masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3],
-                                                    masks->nb[0], masks->nb[1], masks->nb[2],
-                                                    masks->nb[2] /* offset*/);
-    // ggml_set_name(queries, "queries");
-    // ggml_set_name(upscaled_embedding, "upscaled_embedding");
-    // ggml_set_name(state.low_res_masks, "low_res_masks");
-    // ggml_set_name(state.iou_predictions, "iou_predictions");
-    // ggml_set_name(hyper_in, "hyper_in");
-
-    ggml_build_forward_expand(&gf, state.iou_predictions);
-    ggml_build_forward_expand(&gf, state.low_res_masks);
-
-    // run the computation
-    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
-
-    // auto * t = ggml_get_tensor(ctx0, "queries");
-    // print_t_f32("queries", t);
-    // t = ggml_get_tensor(ctx0, "upscaled_embedding");
-    // print_t_f32("upscaled_embedding", t);
-    // t = ggml_get_tensor(ctx0, "low_res_masks");
-    // print_t_f32("low_res_masks", t);
-    // t = ggml_get_tensor(ctx0, "iou_predictions");
-    // print_t_f32("iou_predictions", t);
-    // t = ggml_get_tensor(ctx0, "hyper_in");
-    // print_t_f32("hyper_in", t);
+    iou_pred = ggml_cpy(state.ctx, ggml_view_1d(ctx0, iou_pred, iou_pred->ne[0] - 1, iou_pred->nb[0]), state.iou_predictions);
+    masks = ggml_view_4d(ctx0, masks, masks->ne[0], masks->ne[1], masks->ne[2] - 1, masks->ne[3],
+                                      masks->nb[1], masks->nb[2], masks->nb[3], masks->nb[2] /* offset*/);
+    masks = ggml_cpy(state.ctx, masks, state.low_res_masks);
+
+    ggml_build_forward_expand(gf, masks);
+    ggml_build_forward_expand(gf, iou_pred);
+
+    ggml_disconnect_node_from_graph(state.low_res_masks);
+    ggml_disconnect_node_from_graph(state.iou_predictions);
 
-    ggml_free(ctx0);
     return true;
 }
 
@@ -1929,6 +1874,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state
 
     for (int i = 0; i < ne2; ++i) {
         if (iou_threshold > 0.f && iou_data[i] < iou_threshold) {
+            printf("Skipping mask %d with iou %f below threshold %f\n", i, iou_data[i], iou_threshold);
             continue; // Filtering masks with iou below the threshold
         }
 
@@ -2033,6 +1979,7 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state
 
         const float stability_score = float(intersections) / float(unions);
         if (stability_score_threshold > 0.f && stability_score < stability_score_threshold) {
+            printf("Skipping mask %d with stability score %f below threshold %f\n", i, stability_score, stability_score_threshold);
             continue; // Filtering masks with stability score below the threshold
         }
 
@@ -2050,12 +1997,105 @@ bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state
     return true;
 }
 
+struct ggml_cgraph  * sam_build_fast_graph(
+        const sam_model     & model,
+                  sam_state & state,
+                        int   nx,
+                        int   ny,
+                  sam_point   point) {
+
+    struct ggml_init_params ggml_params = {
+        /*.mem_size   =*/ state.buf_compute_fast.size(),
+        /*.mem_buffer =*/ state.buf_compute_fast.data(),
+        /*.no_alloc   =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements
+    };
+
+    struct ggml_context * ctx0   = ggml_init(ggml_params);
+    struct ggml_cgraph  * gf     = ggml_new_graph(ctx0);
+
+    prompt_encoder_result enc_res = sam_encode_prompt(model, ctx0, gf, state, nx, ny, point);
+    if (!enc_res.embd_prompt_sparse || !enc_res.embd_prompt_dense) {
+        fprintf(stderr, "%s: failed to encode prompt\n", __func__);
+        return {};
+    }
+
+    struct ggml_tensor * pe_img_dense = sam_fill_dense_pe(model, ctx0, gf, state);
+    if (!pe_img_dense) {
+        fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__);
+        return {};
+    }
+
+    if (!sam_decode_mask(model, enc_res, pe_img_dense, ctx0, gf, state)) {
+         fprintf(stderr, "%s: failed to decode mask\n", __func__);
+         return {};
+    }
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+struct sam_params {
+    int32_t seed      = -1; // RNG seed
+    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+
+    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
+    std::string fname_inp = "img.jpg";
+    std::string fname_out = "img.out";
+};
+
+void sam_print_usage(int argc, char ** argv, const sam_params & params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
+    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "  -i FNAME, --inp FNAME\n");
+    fprintf(stderr, "                        input file (default: %s)\n", params.fname_inp.c_str());
+    fprintf(stderr, "  -o FNAME, --out FNAME\n");
+    fprintf(stderr, "                        output file (default: %s)\n", params.fname_out.c_str());
+    fprintf(stderr, "\n");
+}
+
+bool sam_params_parse(int argc, char ** argv, sam_params & params) {
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-s" || arg == "--seed") {
+            params.seed = std::stoi(argv[++i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            params.n_threads = std::stoi(argv[++i]);
+        } else if (arg == "-m" || arg == "--model") {
+            params.model = argv[++i];
+        } else if (arg == "-i" || arg == "--inp") {
+            params.fname_inp = argv[++i];
+        } else if (arg == "-o" || arg == "--out") {
+            params.fname_out = argv[++i];
+        } else if (arg == "-h" || arg == "--help") {
+            sam_print_usage(argc, argv, params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            sam_print_usage(argc, argv, params);
+            exit(0);
+        }
+    }
+
+    return true;
+}
+
 int main(int argc, char ** argv) {
     const int64_t t_main_start_us = ggml_time_us();
 
     sam_params params;
     params.model = "models/sam-vit-b/ggml-model-f16.bin";
 
+    sam_model model;
+    sam_state state;
+    int64_t t_load_us = 0;
+
     if (sam_params_parse(argc, argv, params) == false) {
         return 1;
     }
@@ -2063,7 +2103,6 @@ int main(int argc, char ** argv) {
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
-
     fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
 
     // load the image
@@ -2072,7 +2111,6 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str());
         return 1;
     }
-
     fprintf(stderr, "%s: loaded image '%s' (%d x %d)\n", __func__, params.fname_inp.c_str(), img0.nx, img0.ny);
 
     // preprocess to f32
@@ -2081,24 +2119,8 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to preprocess image\n", __func__);
         return 1;
     }
-
     fprintf(stderr, "%s: preprocessed image (%d x %d)\n", __func__, img1.nx, img1.ny);
 
-#if 0
-    {
-        const int n = 128;
-        fprintf(stderr, "%s: first %d diagonal pixels:\n", __func__, n);
-        for (int i = 0; i < n; i++) {
-            const int ii = i*img1.nx + i;
-            fprintf(stderr, "%s:   %d: %f %f %f\n", __func__, i, img1.data[3*ii + 0], img1.data[3*ii + 1], img1.data[3*ii + 2]);
-        }
-    }
-#endif
-
-    int64_t t_load_us = 0;
-
-    sam_model model;
-    sam_state state;
 
     // load the model
     {
@@ -2115,48 +2137,97 @@ int main(int argc, char ** argv) {
     {
         static size_t buf_size = 256u*1024*1024;
 
-        struct ggml_init_params params = {
+        struct ggml_init_params ggml_params = {
             /*.mem_size   =*/ buf_size,
             /*.mem_buffer =*/ NULL,
             /*.no_alloc   =*/ false,
         };
 
-        state.ctx = ggml_init(params);
+        state.ctx = ggml_init(ggml_params);
 
         state.embd_img = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
                 model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);
 
-        // TODO: should depend on the number of points / boxes / etc
-        state.embd_prompt_sparse = ggml_new_tensor_2d(state.ctx, GGML_TYPE_F32, model.hparams.n_enc_out_chans, 2);
+        state.low_res_masks = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
+                model.hparams.n_enc_out_chans, model.hparams.n_enc_out_chans, 3);
 
-        state.embd_prompt_dense  = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
-                model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);
-
-        state.pe_img_dense = ggml_new_tensor_3d(state.ctx, GGML_TYPE_F32,
-                model.hparams.n_img_embd(), model.hparams.n_img_embd(), model.hparams.n_enc_out_chans);
+        state.iou_predictions = ggml_new_tensor_1d(state.ctx, GGML_TYPE_F32, 3);
     }
 
-    if (!sam_fill_dense_pe(model, state, params.n_threads)) {
-        fprintf(stderr, "%s: failed to get dense positional encoding\n", __func__);
-        return 1;
-    }
 
-    if (!sam_encode_image(model, state, img1, params.n_threads)) {
-        fprintf(stderr, "%s: failed to encode image\n", __func__);
-        return 1;
-    }
+    static const size_t tensor_alignment = 32;
+    {
+        state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+        state.allocr = ggml_allocr_new_measure(tensor_alignment);
+        struct ggml_cgraph * gf_measure = sam_encode_image(model, state, img1);
+        if (!gf_measure) {
+            fprintf(stderr, "%s: failed to encode image\n", __func__);
+            return 1;
+        }
 
-    // TODO: user input
-    const sam_point pt = { 414.375f, 162.796875f, };
+        size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment;
+        ggml_allocr_free(state.allocr);
 
-    if (!sam_encode_prompt(model, state, img0.nx, img0.ny, pt, params.n_threads)) {
-        fprintf(stderr, "%s: failed to encode prompt\n", __func__);
-        return 1;
+        // recreate allocator with exact memory requirements
+        state.buf_alloc_img_enc.resize(alloc_size);
+        state.allocr = ggml_allocr_new(state.buf_alloc_img_enc.data(), state.buf_alloc_img_enc.size(), tensor_alignment);
+
+        // compute the graph with the measured exact memory requirements from above
+        ggml_allocr_reset(state.allocr);
+
+        struct ggml_cgraph  * gf = sam_encode_image(model, state, img1);
+        if (!gf) {
+            fprintf(stderr, "%s: failed to encode image\n", __func__);
+            return 1;
+        }
+
+        ggml_allocr_alloc_graph(state.allocr, gf);
+
+        ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);
+
+        print_t_f32("embd_img", state.embd_img);
+
+        ggml_allocr_free(state.allocr);
+        state.allocr = NULL;
+        state.work_buffer.clear();
     }
+    {
+        state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+        state.allocr = ggml_allocr_new_measure(tensor_alignment);
+
+        // TODO: user input
+        const sam_point pt = { 414.375f, 162.796875f, };
+        // measure memory requirements for the graph
+        struct ggml_cgraph  * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
+        if (!gf_measure) {
+            fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
+            return 1;
+        }
 
-    if (!sam_decode_mask(model, state, params.n_threads)) {
-        fprintf(stderr, "%s: failed to decode mask\n", __func__);
-        return 1;
+        size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment;
+        ggml_allocr_free(state.allocr);
+
+        // recreate allocator with exact memory requirements
+        state.buf_alloc_fast.resize(alloc_size);
+        state.allocr = ggml_allocr_new(state.buf_alloc_fast.data(), state.buf_alloc_fast.size(), tensor_alignment);
+
+        // compute the graph with the measured exact memory requirements from above
+        ggml_allocr_reset(state.allocr);
+
+        struct ggml_cgraph  * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
+        if (!gf) {
+            fprintf(stderr, "%s: failed to build fast graph\n", __func__);
+            return 1;
+        }
+
+        ggml_allocr_alloc_graph(state.allocr, gf);
+
+        ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);
+
+        //print_t_f32("iou_predictions", state.iou_predictions);
+        //print_t_f32("low_res_masks", state.low_res_masks);
+        ggml_allocr_free(state.allocr);
+        state.allocr = NULL;
     }
 
     if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
index 140e9a2a7370a6fb6c32f88f6abbf3f4ae081979..856a4cdbc613405bcb367e9a4ca86ef297d31d06 100644 (file)
@@ -378,6 +378,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_SET:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_CONT:
+        case GGML_OP_ADD_REL_POS:
             return true;
 
         default:
index dadb30757a962c7d5db44a0d854403f7dd5b8c21..b199bb1a6a3e31c4d92ab5937025f17826f0964b 100644 (file)
@@ -7464,8 +7464,6 @@ static struct ggml_tensor * ggml_add_rel_pos_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
-
     result->op   = GGML_OP_ADD_REL_POS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
@@ -14978,8 +14976,11 @@ static void ggml_compute_forward_add_rel_pos_f32(
         const struct ggml_tensor * src1,
         const struct ggml_tensor * src2,
         struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(src0->nb[0] == dst->nb[0] && src0->nb[1] == dst->nb[1]
+             && src0->nb[2] == dst->nb[2] && src0->nb[3] == dst->nb[3]);
 
-    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
+    const bool inplace = dst->data == src0->data;
     if (!inplace && params->type == GGML_TASK_INIT) {
         memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
         return;