]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : generalize GGML_OP_CONCAT (llama/7563)
authorGeorgi Gerganov <redacted>
Tue, 28 May 2024 08:04:19 +0000 (11:04 +0300)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
* ggml : generalize GGML_OP_CONCAT (WIP)

ggml-ci

* tests : add dim != 2 tests

* metal : generalize concat kernel

* tests : naming

* cuda : generalize concat kernel

ggml-ci

* sycl : add warning and assert

* ggml : fix op params handling

* metal : bugfix kernel

ggml-ci

* ggml : reimplement CPU and Metal

* cuda : add asserts

ggml-ci

* ggml : fix ptrs

ggml-ci

include/ggml/ggml.h
src/ggml-cuda/concat.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-sycl.cpp
src/ggml.c
tests/test-backend-ops.cpp

index 6706807b01ca14b45f1aa0d6ae8294c109efb926..3859895b6e72db5667b2801f4a72a27c482da138 100644 (file)
@@ -1007,12 +1007,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
-    // concat a and b on dim 2
+    // concat a and b along dim
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_concat(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            int                   dim);
 
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
index 2941d2f1770a86d7fc77b0e385de8461d526d417..fb9dee8f8cee55b49f06b4f76181189924c98750 100644 (file)
@@ -1,15 +1,68 @@
 #include "concat.cuh"
 
-static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
     }
-    // operation
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (nidx < ne00) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * gridDim.y;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            (nidx - ne00) +
+            blockIdx.y * (ne0 - ne00) +
+            blockIdx.z * (ne0 - ne00) * gridDim.y;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (blockIdx.y < ne01) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            (blockIdx.y - ne01) * ne0 +
+            blockIdx.z * ne0 * (gridDim.y - ne01);
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
     int offset_dst =
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
+
     if (blockIdx.z < ne02) { // src0
         int offset_src =
             nidx +
@@ -25,25 +78,53 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
     }
 }
 
-static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2);
-    concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+    if (dim == 0) {
+        concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+        return;
+    }
+    if (dim == 1) {
+        concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+        return;
+    }
+    concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }
 
 void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-        concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream);
+    if (dim != 3) {
+        for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+            concat_f32_cuda(
+                    src0_d + i3 * (src0->nb[3] / 4),
+                    src1_d + i3 * (src1->nb[3] / 4),
+                     dst_d + i3 * ( dst->nb[3] / 4),
+                    src0->ne[0], src0->ne[1], src0->ne[2],
+                     dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);
+        }
+    } else {
+        const size_t size0 = ggml_nbytes(src0);
+        const size_t size1 = ggml_nbytes(src1);
+
+        CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+        CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
     }
 }
index ff9ae55aada74247a20e14c330553ffdbc85cee5..4ba498e87f9d04064766502dd61ce9a2c5f17eef 100644 (file)
@@ -990,6 +990,8 @@ static enum ggml_status ggml_metal_graph_compute(
                     {
                         id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
 
+                        const int32_t dim = ((int32_t *) dst->op_params)[0];
+
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1018,6 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
                         [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
                         [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                        [encoder setBytes:&dim  length:sizeof(dim)  atIndex:27];
 
                         const int nth = MIN(1024, ne0);
 
index 174086b5b62939763b4afc1873d6d523914f3d49..b16f2b7e0c74f13456632c906f52d9e62a0d7fb4 100644 (file)
@@ -3366,31 +3366,30 @@ kernel void kernel_concat(
     constant  uint64_t & nb1,
     constant  uint64_t & nb2,
     constant  uint64_t & nb3,
+    constant   int32_t & dim,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
 
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+    device const float * x;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        if (i02 < ne02) {
-            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
-            src0_ptr += ntg.x*nb00;
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
         } else {
-            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
-            src1_ptr += ntg.x*nb10;
+            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
         }
-        dst_ptr += ntg.x*nb0;
+
+        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
     }
 }
 
index 8839f775d5b880c507ef5d8efa96ac20497d2eb7..d5384b2e065a863d4cffede3a0a362987cec8eb9 100644 (file)
@@ -13512,6 +13512,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
                                 const float *src0_dd, const float *src1_dd,
                                 float *dst_dd,
                                 const dpct::queue_ptr &main_stream) {
+#pragma message("TODO: generalize concat kernel for dim != 2")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7563")
+    int dim = dst->op_params[0];
+    GGML_ASSERT(dim != 2);
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
index 535a36800bc4ab1d08d2b57d2405db755b97f65c..3ef9842e04d67abf5bfe6747ee9df80a052531ee 100644 (file)
@@ -4882,10 +4882,21 @@ struct ggml_tensor * ggml_repeat_back(
 // ggml_concat
 
 struct ggml_tensor * ggml_concat(
-    struct ggml_context* ctx,
-    struct ggml_tensor* a,
-    struct ggml_tensor* b) {
-    GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d == dim) {
+            ne[d] = a->ne[d] + b->ne[d];
+            continue;
+        }
+        GGML_ASSERT(a->ne[d] == b->ne[d]);
+        ne[d] = a->ne[d];
+    }
 
     bool is_node = false;
 
@@ -4893,7 +4904,9 @@ struct ggml_tensor * ggml_concat(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, dim);
 
     result->op = GGML_OP_CONCAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5013,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
     ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
     result->op   = GGML_OP_LEAKY_RELU;
@@ -10977,26 +10991,29 @@ static void ggml_compute_forward_concat_f32(
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb10 == sizeof(float));
 
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
+    const float * x;
+
+    // TODO: smarter multi-theading
     for (int i3 = 0; i3 < ne3; i3++) {
         for (int i2 = ith; i2 < ne2; i2 += nth) {
-            if (i2 < ne02) { // src0
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
-                    }
-                }
-            } // src1
-            else {
-                for (int i1 = 0; i1 < ne1; i1++) {
-                    for (int i0 = 0; i0 < ne0; i0++) {
-                        const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
-
-                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
-                        *y = *x;
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
                     }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
                 }
             }
         }
@@ -11004,7 +11021,7 @@ static void ggml_compute_forward_concat_f32(
 }
 
 static void ggml_compute_forward_concat(
-    const struct ggml_compute_params* params,
+    const struct ggml_compute_params * params,
     struct ggml_tensor* dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
index de74585da29dd40905f4b2cc9a0aec176e5c5358..b200ccccd51b00075bd687834262f8b106b596d9 100644 (file)
@@ -1259,22 +1259,26 @@ struct test_im2col : public test_case {
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
-    const std::array<int64_t, 4> ne;
-    const int64_t b_ne2;
+    const std::array<int64_t, 4> ne_a;
+    const int64_t ne_b_d;
+    const int dim;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, b_ne2);
+        return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
     }
 
     test_concat(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
-            int64_t b_ne2 = 10)
-        : type(type), ne(ne), b_ne2(b_ne2) {}
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
+            int64_t ne_b_d = 10,
+            int dim = 2)
+        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], b_ne2, ne[3]);
-        ggml_tensor * out = ggml_concat(ctx, a, b);
+        auto ne_b = ne_a;
+        ne_b[dim] = ne_b_d;
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_concat(ctx, a, b, dim);
         return out;
     }
 };
@@ -2211,8 +2215,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
-    test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
-    test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
+    for (int dim : { 0, 1, 2, 3, }) {
+        test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
+        test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
+    }
 
     for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));