]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : improve ADD_REL_POS perf in SAM by doing it inplace + broadcast BLAS mul_mat...
authorYavor Ivanov <redacted>
Mon, 21 Aug 2023 12:31:27 +0000 (15:31 +0300)
committerGitHub <redacted>
Mon, 21 Aug 2023 12:31:27 +0000 (15:31 +0300)
* Improve ADD_REL_POS perf in SAM by doing it inplace

- Add unit tests for the ADD_REL_POS operation
- I am not sure if this is valid implementation as we reuse the src0
  memory in order to avoid copying it
- When running SAM with the "Example output" command, image, point and
  16 threads, this reduces the cumulative time of the ADD_REL_POS operation
  from 1000-1100 ms to 180-200ms
- There is further room for optimization in the access patterns used in
  the implementation of the opration

* Add non-inplace version for the GGML_OP_ADD_REL_POS

* Fix map_unary warnings and refactor LayerNorm2d + remove ggml_cont in it

* Fix Mac printf format warnings

* sam : add ggml_graph_print() comment

* ggml : add broadcast support for BLAS ggml_mul_mat() (#460)

* Remove not needed build_forward_expand from add-rel-pos unit test

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/sam/main.cpp
include/ggml/ggml.h
src/ggml.c
tests/CMakeLists.txt
tests/test-rel-pos.c [new file with mode: 0644]

index 320907cc94139b648df6e795f0c128092b68359a..9f6cd838290af00db9552f69cac28e23d921032c 100644 (file)
@@ -269,15 +269,41 @@ struct sam_image_f32 {
     std::vector<float> data;
 };
 
-void ggml_sam_sin(const int n, float * dst, const float * src) {
-    for (int i = 0; i < n; ++i) {
-        dst[i] = sinf(src[i]);
+void ggml_sam_sin(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
+    GGML_ASSERT(userdata == NULL);
+    GGML_ASSERT(ggml_are_same_shape(dst, src));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src));
+
+    const float * src_data = ggml_get_data_f32(src);
+    float * dst_data = ggml_get_data_f32(dst);
+
+    const int ne = (int)ggml_nelements(dst);
+    const int dr = (ne + nth - 1) / nth;
+    const int ie0 = dr * ith;
+    const int ie1 = std::min(ie0 + dr, ne);
+
+    for (int i = ie0; i < ie1; ++i) {
+        dst_data[i] = sinf(src_data[i]);
     }
 }
 
-void ggml_sam_cos(const int n, float * dst, const float * src) {
-    for (int i = 0; i < n; ++i) {
-        dst[i] = cosf(src[i]);
+void ggml_sam_cos(struct ggml_tensor * dst , const struct ggml_tensor * src, int ith, int nth, void * userdata) {
+    GGML_ASSERT(userdata == NULL);
+    GGML_ASSERT(ggml_are_same_shape(dst, src));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src));
+
+    const float * src_data = ggml_get_data_f32(src);
+    float * dst_data = ggml_get_data_f32(dst);
+
+    const int ne = (int)ggml_nelements(dst);
+    const int dr = (ne + nth - 1) / nth;
+    const int ie0 = dr * ith;
+    const int ie1 = std::min(ie0 + dr, ne);
+
+    for (int i = ie0; i < ie1; ++i) {
+        dst_data[i] = cosf(src_data[i]);
     }
 }
 
@@ -888,13 +914,6 @@ bool sam_model_load(const std::string & fname, sam_model & model) {
         }
     }
 
-    // key + value memory
-    {
-        // const auto & hparams = model.hparams;
-
-        // TODO
-    }
-
     // load weights
     {
         int n_tensors = 0;
@@ -1037,8 +1056,8 @@ bool sam_fill_dense_pe(
     // concat
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
     {
-        struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
-        struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);
+        struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
+        struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);
 
         cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1], cur->ne[2]);
 
@@ -1059,6 +1078,28 @@ bool sam_fill_dense_pe(
     return true;
 }
 
+struct ggml_tensor* sam_layer_norm_2d(
+                    struct ggml_context * ctx0,
+                    struct ggml_tensor  * layer,
+                    int                   n_channels,
+                    struct ggml_tensor  * w,
+                    struct ggml_tensor  * b) {
+    // LayerNorm2d
+    // normalize along channel dimmension
+    // TODO: better implementation
+    layer = ggml_permute(ctx0,
+                ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3))),
+                2, 0, 1, 3);
+
+    layer = ggml_add(ctx0,
+              ggml_mul(ctx0,
+                  ggml_repeat(ctx0, ggml_reshape_3d(ctx0, w, 1, 1, n_channels), layer),
+                  layer),
+              ggml_repeat(ctx0, ggml_reshape_3d(ctx0, b, 1, 1, n_channels), layer));
+
+    return layer;
+}
+
 bool sam_encode_image(
             const sam_model & model,
                   sam_state & state,
@@ -1228,7 +1269,7 @@ bool sam_encode_image(
                         0, 2, 1, 3));
             struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
 
-            struct ggml_tensor * attn = ggml_add_rel_pos(ctx0, KQ_scaled, rel_w, rel_h);
+            struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
 
@@ -1306,37 +1347,11 @@ bool sam_encode_image(
 
     cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);
 
-    // LayerNorm2d
-    {
-        // normalize along channel dimmension
-        // TODO: better implementation
-        cur = ggml_cont(ctx0, ggml_permute(ctx0,
-                    ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3))),
-                    2, 0, 1, 3));
-
-        cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_0_w, 1, 1, n_enc_out_chans), cur),
-                    cur),
-                ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_0_b, 1, 1, n_enc_out_chans), cur));
-    }
+    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b);
 
     cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);
 
-    // LayerNorm2d
-    {
-        // normalize along channel dimmension
-        // TODO: better implementation
-        cur = ggml_cont(ctx0, ggml_permute(ctx0,
-                    ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3))),
-                    2, 0, 1, 3));
-
-        cur = ggml_add(ctx0,
-                ggml_mul(ctx0,
-                    ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_1_w, 1, 1, n_enc_out_chans), cur),
-                    cur),
-                ggml_repeat(ctx0, ggml_reshape_3d(ctx0, enc.neck_norm_1_b, 1, 1, n_enc_out_chans), cur));
-    }
+    cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b);
 
     // TODO: avoid copy
     cur = ggml_cpy(ctx0, cur, state.embd_img);
@@ -1349,6 +1364,8 @@ bool sam_encode_image(
     ggml_build_forward_expand(&gf, cur);
     ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
+    //ggml_graph_print(&gf);
+
     ggml_free(ctx0);
     return true;
 }
@@ -1423,8 +1440,8 @@ bool sam_encode_prompt(
     // concat
     // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L192
     {
-        struct ggml_tensor * t_sin = ggml_map_unary_f32(ctx0, cur, ggml_sam_sin);
-        struct ggml_tensor * t_cos = ggml_map_unary_f32(ctx0, cur, ggml_sam_cos);
+        struct ggml_tensor * t_sin = ggml_map_custom1(ctx0, cur, ggml_sam_sin, GGML_N_TASKS_MAX, NULL);
+        struct ggml_tensor * t_cos = ggml_map_custom1(ctx0, cur, ggml_sam_cos, GGML_N_TASKS_MAX, NULL);
 
         cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, t_sin->ne[0] + t_cos->ne[0], cur->ne[1]);
 
@@ -1462,74 +1479,6 @@ bool sam_encode_prompt(
     // run the computation
     ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
-    // print
-    {
-        // auto print_t_f32 = [&](struct ggml_tensor * t) {
-        //     float * data = (float *)t->data;
-        //     printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
-        //     printf("data: ");
-        //     for (int i = 0; i < std::min((int) t->ne[0], 256); i++) {
-        //         printf("%f ", data[i]);
-        //     }
-        //     printf("\n");
-        //     //for (int y = 0; y < 64; ++y) {
-        //     //    for (int x = 0; x < 64; ++x) {
-        //     //        printf("%5.2f ", data[y*64 + x]);
-        //     //    }
-        //     //    printf("\n");
-        //     //}
-        //     //printf("\n");
-        //     // for (int y = 0; y < 64; ++y) {
-        //     //     for (int x = 0; x < 64; ++x) {
-        //     //         printf("%5.2f ", data[255*64*64 + y*64 + x]);
-        //     //     }
-        //     //     printf("\n");
-        //     // }
-        //     // printf("\n");
-        //     //for (int y = 0; y < 64; ++y) {
-        //     //    for (int x = 0; x < 64; ++x) {
-        //     //        printf("%5.2f ", data[(y*64 + x)*768 + 231]);
-        //     //    }
-        //     //    printf("\n");
-        //     //}
-        //     //printf("\n");
-        //     double sum = 0.0;
-        //     for (int i = 0; i < ggml_nelements(t); i++) {
-        //         sum += data[i];
-        //     }
-        //     printf("sum:  %f\n", sum);
-        // };
-
-        // auto print_t_f16 = [&](struct ggml_tensor * t) {
-        //     ggml_fp16_t * data = (ggml_fp16_t *)t->data;
-        //     printf("dims: %jd %jd %jd %jd f16\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
-        //     printf("data: ");
-        //     for (int i = 0; i < std::min((int) t->ne[0], 256); i++) {
-        //         printf("%f ", ggml_fp16_to_fp32(data[i]));
-        //     }
-        //     printf("\n");
-        //     for (int y = 0; y < 14; ++y) {
-        //         for (int x = 0; x < 14; ++x) {
-        //             printf("%7.4f ", ggml_fp16_to_fp32(data[(y*14 + x)*64 + 23]));
-        //         }
-        //         printf("\n");
-        //     }
-        //     printf("\n");
-        //     double sum = 0.0;
-        //     for (int i = 0; i < ggml_nelements(t); i++) {
-        //         sum += ggml_fp16_to_fp32(data[i]);
-        //     }
-        //     printf("sum:  %f\n", sum);
-        // };
-
-        // auto * t = ggml_get_tensor(ctx0, "check");
-        // if (t->type == GGML_TYPE_F32) {
-        //     print_t_f32(t);
-        // } else {
-        //     print_t_f16(t);
-        // }
-    }
-
     //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
 
     ggml_free(ctx0);
@@ -1595,7 +1544,7 @@ struct ggml_tensor* sam_decode_mask_transformer_attn(
 
     struct ggml_tensor * KQV_merged = ggml_cont(ctx0, ggml_transpose(ctx0, KQV));
     KQV_merged = ggml_cont(ctx0, ggml_permute(ctx0, KQV_merged, 0, 2, 1, 3));
-    KQV_merged = ggml_cont(ctx0, ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]));
+    KQV_merged = ggml_reshape_3d(ctx0, KQV_merged, KQV_merged->ne[0]*KQV_merged->ne[1], KQV_merged->ne[2], KQV_merged->ne[3]);
     KQV_merged = ggml_mul_mat(ctx0, attn.out_w, KQV_merged);
     KQV_merged = ggml_add(ctx0,
                 ggml_repeat(ctx0, attn.out_b, KQV_merged),
@@ -1859,21 +1808,7 @@ bool sam_decode_mask(
         // ConvTranspose2d
         keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
         keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_0_b, keys), keys);
-
-        // LayerNorm2d
-        {
-            // normalize along channel dimmension
-            // TODO: better implementation
-            keys = ggml_cont(ctx0, ggml_permute(ctx0,
-                        ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, keys, 1, 2, 0, 3))),
-                        2, 0, 1, 3));
-
-            keys = ggml_add(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_1_w, 1, 1, n_img_embd), keys),
-                        keys),
-                    ggml_repeat(ctx0, ggml_reshape_3d(ctx0, dec.output_upscaling_1_b, 1, 1, n_img_embd), keys));
-        }
+        keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b);
 
         // GELU activation
         keys = ggml_gelu(ctx0, keys);
@@ -1898,7 +1833,7 @@ bool sam_decode_mask(
 
     struct ggml_tensor * masks = ggml_mul_mat(ctx0, hyper_in, upscaled_embedding);
     masks = ggml_cont(ctx0, ggml_transpose(ctx0, masks)); // TODO: Shouldn't be needed
-    masks = ggml_cont(ctx0, ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]));
+    masks = ggml_reshape_4d(ctx0, masks, keys->ne[0], keys->ne[1], masks->ne[1], keys->ne[3]);
 
     // Generate mask quality predictions
     // ref: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/mask_decoder.py#L146
@@ -1941,7 +1876,7 @@ bool sam_decode_mask(
 bool sam_write_masks(const sam_hparams& hparams, int nx, int ny, const sam_state & state) {
     if (state.low_res_masks->ne[2] == 0) return true;
     if (state.low_res_masks->ne[2] != state.iou_predictions->ne[0]) {
-        printf("Error: number of masks (%jd) does not match number of iou predictions (%jd)\n", state.low_res_masks->ne[2], state.iou_predictions->ne[0]);
+        printf("Error: number of masks (%d) does not match number of iou predictions (%d)\n", (int)state.low_res_masks->ne[2], (int)state.iou_predictions->ne[0]);
         return false;
     }
 
index 1ee5492fbc9a774103f1cc9f6b5ecc46923b8efd..8ffeeb81d47343d3cd2bf7389d678b1b277d3553 100644 (file)
@@ -1384,12 +1384,19 @@ extern "C" {
             int                   kh);
 
     // used in sam
+
     GGML_API struct ggml_tensor * ggml_add_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);
 
+    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
index 3d03d1ad52d49b10178f3d7277d9ae42047068ed..6d4d133066dccc9b59ba4fd1d8fe32337542d53e 100644 (file)
@@ -7329,33 +7329,30 @@ struct ggml_tensor * ggml_get_rel_pos(
 
 // ggml_add_rel_pos
 
-struct ggml_tensor * ggml_add_rel_pos(
+static struct ggml_tensor * ggml_add_rel_pos_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * pw,
-        struct ggml_tensor  * ph) {
-    GGML_ASSERT(pw->ne[0] == ph->ne[0]);
-    GGML_ASSERT(pw->ne[1] == ph->ne[1]);
-    GGML_ASSERT(pw->ne[2] == ph->ne[2]);
-    GGML_ASSERT(pw->ne[3] == ph->ne[3]);
-    GGML_ASSERT(pw->ne[3] == a->ne[2]);
-    GGML_ASSERT(pw->ne[0]*ph->ne[0] == a->ne[0]);
-    GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
+        struct ggml_tensor  * ph,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_are_same_shape(pw, ph));
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_is_contiguous(pw));
     GGML_ASSERT(ggml_is_contiguous(ph));
-    GGML_ASSERT(pw->type == GGML_TYPE_F32);
     GGML_ASSERT(ph->type == GGML_TYPE_F32);
+    GGML_ASSERT(pw->type == GGML_TYPE_F32);
+    GGML_ASSERT(pw->ne[3] == a->ne[2]);
+    GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
+    GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
 
     bool is_node = false;
 
-    if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
+    if (!inplace && (a->grad || pw->grad || ph->grad)) {
         is_node = true;
     }
 
-    const int64_t ne[4] = { a->ne[0], a->ne[1], a->ne[2], 1, };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
 
     result->op   = GGML_OP_ADD_REL_POS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7366,6 +7363,23 @@ struct ggml_tensor * ggml_add_rel_pos(
     return result;
 }
 
+
+struct ggml_tensor * ggml_add_rel_pos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph) {
+    return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);
+}
+
+struct ggml_tensor * ggml_add_rel_pos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph) {
+    return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
+}
+
 // gmml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -10774,6 +10788,10 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
@@ -10793,11 +10811,6 @@ static void ggml_compute_forward_mul_mat(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
-        //       ref: https://github.com/ggerganov/ggml/pull/224
-        GGML_ASSERT(ne02 == ne12);
-        GGML_ASSERT(ne03 == ne13);
-
         if (params->ith != 0) {
             return;
         }
@@ -10810,12 +10823,16 @@ static void ggml_compute_forward_mul_mat(
             return;
         }
 
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                const void * x = (char *) src0->data + i03*nb03 + i02*nb02;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+        for (int64_t i13 = 0; i13 < ne13; i13++) {
+            for (int64_t i12 = 0; i12 < ne12; i12++) {
+                // broadcast src0 into src1 across 2nd,3rd dimension
+                const int64_t i03 = i13/r3;
+                const int64_t i02 = i12/r2;
 
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+                const void  * x = (char *)            src0->data + i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+
+                float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
 
                 if (type != GGML_TYPE_F32) {
                             float * const wdata    = params->wdata;
@@ -10823,7 +10840,7 @@ static void ggml_compute_forward_mul_mat(
 
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        to_float((const char *) x + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
 
@@ -10903,10 +10920,6 @@ static void ggml_compute_forward_mul_mat(
     assert(ne12 % ne02 == 0);
     assert(ne13 % ne03 == 0);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
     // block-tiling attempt
     const int64_t blck_0 = 16;
     const int64_t blck_1 = 16;
@@ -14586,7 +14599,6 @@ static void ggml_compute_forward_get_rel_pos_f16(
     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = 0; i1 < ne1; ++i1) {
             const int64_t pos = (w - i1 - 1) + i2;
-
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
             }
@@ -14618,26 +14630,25 @@ static void ggml_compute_forward_add_rel_pos_f32(
         const struct ggml_tensor * src1,
         const struct ggml_tensor * src2,
         struct ggml_tensor * dst) {
-    if (params->type == GGML_TASK_FINALIZE) {
+
+    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
+    if (!inplace && params->type == GGML_TASK_INIT) {
+        memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
+        return;
+    }
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
 
-    float * src0_data = (float *) src0->data;
     float * src1_data = (float *) src1->data;
     float * src2_data = (float *) src2->data;
     float * dst_data  = (float *) dst->data;
 
-    if (params->type == GGML_TASK_INIT) {
-        memcpy(dst_data, src0_data, ne0*ne1*ne2*sizeof(float));
-        return;
-    }
-
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     const int64_t ne12 = src1->ne[2];
@@ -14656,23 +14667,22 @@ static void ggml_compute_forward_add_rel_pos_f32(
     const int ip0 = dp*ith;
     const int ip1 = MIN(ip0 + dp, np);
 
+
     for (int64_t i13 = ip0; i13 < ip1; ++i13) {
         for (int64_t i12 = 0; i12 < ne12; ++i12) {
             for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
                 for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                    // add rel pos W (src1) to src0
-                    const int64_t i2 = i11;
-                    const int64_t i3 = i12;
-                    const int64_t i4 = i13;
-
-                    const int64_t jp  = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10 + i10;
+                    const int64_t jp0  = jp1 + i10;
+                    const float src1_e = src1_data[jp0];
+                    const float src2_e = src2_data[jp0];
 
-                    const int64_t jdw = i4*ne1*ne0 + i3*ne11*ne0 + i2*ne0 + i10;
-                    const int64_t jdh = i4*ne1*ne0 + i3*ne11*ne0 + i2*ne0 + i10*ne10;
+                    const int64_t jdh = jp0 * ne10;
+                    const int64_t jdw = jdh - (ne10 - 1) * i10;
 
                     for (int64_t j = 0; j < ne10; ++j) {
-                        dst_data[jdw + j*ne10] += src1_data[jp];
-                        dst_data[jdh + j     ] += src2_data[jp];
+                        dst_data[jdh + j     ] += src2_e;
+                        dst_data[jdw + j*ne10] += src1_e;
                     }
                 }
             }
index 3220bd72a00c46e7d5916e0ddfc87153fa6a21ca..33eda08f72b84fc889f5a4420119386807ea242f 100644 (file)
@@ -291,6 +291,14 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 
+#
+# test-rel-pos
+
+set(TEST_TARGET test-rel-pos)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+
 #
 # test-svd0 (arm/x86)
 
diff --git a/tests/test-rel-pos.c b/tests/test-rel-pos.c
new file mode 100644 (file)
index 0000000..19960b4
--- /dev/null
@@ -0,0 +1,84 @@
+#include "ggml/ggml.h"
+
+#include <string.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+struct ggml_context* make_ctx(void) {
+    struct ggml_init_params params = {
+        .mem_size = 2 * 1024 * 1024,
+    };
+
+    return ggml_init(params);
+}
+
+void check_tensor(struct ggml_tensor * t, float * expected_t_d, int ne0, int ne1, int ne2) {
+    GGML_ASSERT(t->type == GGML_TYPE_F32);
+    GGML_ASSERT(t->ne[0] == ne0);
+    GGML_ASSERT(t->ne[1] == ne1);
+    GGML_ASSERT(t->ne[2] == ne2);
+    for (int i2 = 0; i2 < ne2; ++i2) {
+        for (int i1 = 0; i1 < ne1; ++i1) {
+            for (int i0 = 0; i0 < ne0; ++i0) {
+                float expected = *(expected_t_d + i2 * ne1 * ne0 + i1 * ne0 + i0);
+                float actual = ggml_get_data_f32(t)[i2 * ne1 * ne0 + i1 * ne0 + i0];
+                GGML_ASSERT(expected == actual);
+            }
+        }
+    }
+}
+
+int main(int argc, const char** argv) {
+    ggml_fp16_t buf_f16[1024];
+    for (int i = 0; i < 1024; ++i) {
+        buf_f16[i] = ggml_fp32_to_fp16((float)i);
+    }
+
+    float expected_out[4][9] = {
+        { 8.0, 9.0, 10.0, 9.0, 10.0, 11.0, 10.0, 11.0, 12.0 },
+        { 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 4.0, 5.0, 6.0 },
+        { 14.0, 15.0, 16.0, 15.0, 16.0, 17.0, 16.0, 17.0, 18.0 },
+        { 8.0, 9.0, 10.0, 9.0, 10.0, 11.0, 10.0, 11.0, 12.0 },
+    };
+
+    {
+        struct ggml_context * ctx = make_ctx();
+
+
+        struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 3, 3);
+        ggml_fp16_t* t_d = (ggml_fp16_t*)t->data;
+        memcpy(t_d, buf_f16, ggml_nbytes(t));
+
+        struct ggml_tensor * t_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 3, 3);
+        ggml_fp16_t* t_d_2 = (ggml_fp16_t*)t_2->data;
+        memcpy(t_d_2, buf_f16 + 1, ggml_nbytes(t_2));
+
+        struct ggml_tensor * rw = ggml_get_rel_pos(ctx, t, 2, 2);
+        struct ggml_tensor * rh = ggml_get_rel_pos(ctx, t_2, 2, 2);
+
+        struct ggml_tensor * rw_f32 = ggml_cpy(ctx, rw, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3, 2, 2));
+        struct ggml_tensor * rh_f32 = ggml_cpy(ctx, rh, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 3, 2, 2));
+
+        struct ggml_tensor * in = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 9, 4);
+        struct ggml_tensor * out_inplace = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 9, 4);
+        float * in_d          = (float*)in->data;
+        float * out_inplace_d = (float*)out_inplace->data;
+        for (int i = 0; i < ggml_nelements(in); ++i) {
+            in_d[i]          = 1.f;
+            out_inplace_d[i] = 1.f;
+        }
+
+        struct ggml_tensor * out = ggml_add_rel_pos(ctx, in, rw_f32, rh_f32);
+        struct ggml_cgraph gf = ggml_build_forward(out);
+        ggml_graph_compute_with_ctx(ctx, &gf, 1);
+
+        out_inplace = ggml_add_rel_pos_inplace(ctx, out_inplace, rw_f32, rh_f32);
+        struct ggml_cgraph gf_2 = ggml_build_forward(out_inplace);
+        ggml_graph_compute_with_ctx(ctx, &gf_2, 1);
+
+        check_tensor(out, (float*)expected_out, 9, 4, 1);
+        check_tensor(out_inplace, (float*)expected_out, 9, 4, 1);
+    }
+
+    return 0;
+}