]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : non-cont concat support (llama/7610)
authorGeorgi Gerganov <redacted>
Wed, 29 May 2024 12:38:26 +0000 (15:38 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
* tests : add non-cont concat tests

* cuda : non-cont concat support

ggml-ci

src/ggml-cuda/concat.cu
tests/test-backend-ops.cpp

index fb9dee8f8cee55b49f06b4f76181189924c98750..dac10ec36b0bd02dfbf922bb45a79470d4fb7465 100644 (file)
@@ -1,5 +1,6 @@
 #include "concat.cuh"
 
+// contiguous kernels
 static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
@@ -92,39 +93,104 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
     concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }
 
+// non-contiguous kernel (slow)
+static __global__ void concat_f32_non_cont(
+        const char * src0,
+        const char * src1,
+              char * dst,
+           int64_t   ne00,
+           int64_t   ne01,
+           int64_t   ne02,
+           int64_t   ne03,
+          uint64_t   nb00,
+          uint64_t   nb01,
+          uint64_t   nb02,
+          uint64_t   nb03,
+           int64_t /*ne10*/,
+           int64_t /*ne11*/,
+           int64_t /*ne12*/,
+           int64_t /*ne13*/,
+          uint64_t   nb10,
+          uint64_t   nb11,
+          uint64_t   nb12,
+          uint64_t   nb13,
+           int64_t   ne0,
+           int64_t /*ne1*/,
+           int64_t /*ne2*/,
+           int64_t /*ne3*/,
+          uint64_t   nb0,
+          uint64_t   nb1,
+          uint64_t   nb2,
+          uint64_t   nb3,
+          int32_t   dim) {
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+    const float * x;
+
+    for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+        } else {
+            x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+        }
+
+        float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
+    }
+}
+
+
 void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     const int32_t dim = ((int32_t *) dst->op_params)[0];
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    if (dim != 3) {
-        for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-            concat_f32_cuda(
-                    src0_d + i3 * (src0->nb[3] / 4),
-                    src1_d + i3 * (src1->nb[3] / 4),
-                     dst_d + i3 * ( dst->nb[3] / 4),
-                    src0->ne[0], src0->ne[1], src0->ne[2],
-                     dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        const float * src0_d = (const float *)src0->data;
+        const float * src1_d = (const float *)src1->data;
+
+        float * dst_d = (float *)dst->data;
+
+        if (dim != 3) {
+            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+                concat_f32_cuda(
+                        src0_d + i3 * (src0->nb[3] / 4),
+                        src1_d + i3 * (src1->nb[3] / 4),
+                        dst_d + i3 * ( dst->nb[3] / 4),
+                        src0->ne[0], src0->ne[1], src0->ne[2],
+                        dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);
+            }
+        } else {
+            const size_t size0 = ggml_nbytes(src0);
+            const size_t size1 = ggml_nbytes(src1);
+
+            CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
         }
     } else {
-        const size_t size0 = ggml_nbytes(src0);
-        const size_t size1 = ggml_nbytes(src1);
-
-        CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
-        CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+        dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+        concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *)src0->data,
+                (const char *)src1->data,
+                (      char *)dst->data,
+                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+                dst->ne[0],  dst->ne[1],  dst->ne[2],  dst->ne[3],
+                dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3], dim);
     }
 }
index b200ccccd51b00075bd687834262f8b106b596d9..5cde21c660514469f81113ee8145ec5243c71529 100644 (file)
@@ -1262,22 +1262,37 @@ struct test_concat : public test_case {
     const std::array<int64_t, 4> ne_a;
     const int64_t ne_b_d;
     const int dim;
+    const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne_a, ne_b_d, dim);
+        return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
     }
 
     test_concat(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
             int64_t ne_b_d = 10,
-            int dim = 2)
-        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim) {}
+            int dim = 2, int v = 0)
+        : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         auto ne_b = ne_a;
         ne_b[dim] = ne_b_d;
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
-        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        }
+        ggml_tensor * b;
+        if (v & 2) {
+            auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
+            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
+        } else {
+            b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        }
         ggml_tensor * out = ggml_concat(ctx, a, b, dim);
         return out;
     }
@@ -2215,9 +2230,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
-    for (int dim : { 0, 1, 2, 3, }) {
-        test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim));
-        test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim));
+    for (int v : { 0, 1, 2, 3 }) {
+        for (int dim : { 0, 1, 2, 3, }) {
+            test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
+            test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
+        }
     }
 
     for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {