]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
feat: cuda implementation for `ggml_conv_transpose_1d` (ggml/854)
authorJohn Balis <redacted>
Tue, 2 Jul 2024 16:09:52 +0000 (11:09 -0500)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 09:23:00 +0000 (12:23 +0300)
* conv transpose 1d passing test for 1d input and kernel

* working for different input and output channel counts, added test for variable stride

* initial draft appears to work with stride other than 1

* working with all old and new conv1d  tests

* added a test for large tensors

* removed use cuda hardcoding

* restored test-conv-transpose.c

* removed unused arugments, and fixed bug where test failure would cause subsequent tests to fail

* fixed accumulator bug

* added test to test-backend-ops

* fixed mistake

* addressed review

* fixed includes

* removed blank lines

* style and warning fixes

* return failure when test fails

* fix supports_op

---------

Co-authored-by: slaren <redacted>
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/conv-transpose-1d.cu [new file with mode: 0644]
ggml/src/ggml-cuda/conv-transpose-1d.cuh [new file with mode: 0644]
tests/test-backend-ops.cpp

index 1c9ccc8a15e54e61540b3ffc63b4af5fd609a8c8..ed784ea1c20246d7c6854092a9507dd53a69b84c 100644 (file)
@@ -29,6 +29,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2261,6 +2262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            ggml_cuda_op_conv_transpose_1d(ctx,dst);
+            break;
         case GGML_OP_POOL_2D:
             ggml_cuda_op_pool2d(ctx, dst);
             break;
@@ -2804,6 +2808,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                return false;
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ggml/src/ggml-cuda/conv-transpose-1d.cu
new file mode 100644 (file)
index 0000000..b1e94d6
--- /dev/null
@@ -0,0 +1,87 @@
+#include "conv-transpose-1d.cuh"
+
+static  __global__ void conv_transpose_1d_kernel(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst) {
+    int global_index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (global_index >= output_size) {
+        return;
+    }
+
+    int out_index = global_index / dst_ne0;
+
+    float accumulator = 0;
+
+    for (int c = 0; c < src0_ne2; c++) {
+        int idx = global_index % dst_ne0;
+
+        int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+        int input_offset = src1_ne0 * c;
+
+        for (int i = 0; i < src1_ne0; i++) {
+            if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+                continue;
+            }
+            int weight_idx = idx - i*s0;
+
+            float kernel_weight = src0[kernel_offset + weight_idx];
+            float input_value =  src1[input_offset+i];
+
+            accumulator += kernel_weight * input_value;
+        }
+    }
+    dst[global_index] = accumulator;
+}
+
+static void conv_transpose_1d_f32_f32_cuda(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst,
+        cudaStream_t stream) {
+
+    const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
+    conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
+        s0,p0,d0,output_size,
+        src0_ne0, src0_ne1,  src0_ne2, src0_ne3,
+        src1_ne0, src1_ne1,  src1_ne2, src1_ne3,
+        dst_ne0,  dst_ne1,   dst_ne2,  dst_ne3,
+        src0,src1, dst);
+}
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+
+    const int s0 = opts[0];
+    const int p0 = 0;//opts[3];
+    const int d0 = 1;//opts[4];
+
+    const int64_t kernel_size = ggml_nelements(src0);
+    const int64_t input_size = ggml_nelements(src1);
+    const int64_t output_size = ggml_nelements(dst);
+
+    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, src1_d, dst_d, stream);
+}
diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cuh b/ggml/src/ggml-cuda/conv-transpose-1d.cuh
new file mode 100644 (file)
index 0000000..6c2cf66
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 2bb71ac03817fd1b18577227276ff6affb20a607..83d4a0eb153cf4c34860e8c10b2d7e6f946a7366 100644 (file)
@@ -1266,6 +1266,36 @@ struct test_pool2d : public test_case {
     }
 };
 
+// GGML_OP_CONV_TRANSPOSE_1D
+struct test_conv_transpose_1d : public test_case {
+    
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+
+    // stride
+    const int s0;
+    // padding
+    const int p0;
+    // dilation
+    const int d0;
+
+    std::string vars() override {
+        return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
+    }
+
+    test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
+                           std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
+                           int s0 = 1, int p0 = 0, int d0 = 1)
+        : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
+        return out;
+    }
+};
+
 // GGML_OP_IM2COL
 struct test_im2col : public test_case {
     const ggml_type type_input;
@@ -1279,7 +1309,7 @@ struct test_im2col : public test_case {
     // padding
     const int p0;
     const int p1;
-    // dilatation
+    // dilation
     const int d0;
     const int d1;
     // mode
@@ -2098,6 +2128,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
 
+    test_cases.emplace_back(new test_conv_transpose_1d());
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
+    test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
+
+
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));