]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA & CPU: support F32 kernel type for `CONV_TRANSPOSE_2D` (#17094)
authorYihao Wang <redacted>
Thu, 26 Mar 2026 02:19:14 +0000 (19:19 -0700)
committerGitHub <redacted>
Thu, 26 Mar 2026 02:19:14 +0000 (10:19 +0800)
* Refactor CUDA 2D transpose implementation to support multiple kernel types and improve parameter handling

- Introduced a `conv2d_transpose_params` struct for better parameter management.
- Updated `conv2d_transpose_kernel` to be templated for different kernel types (float and half).
- Modified `ggml_cuda_conv_2d_transpose_p0` to handle both F16 and F32 kernel types.
- Enhanced test cases to validate functionality for both kernel types.

* Refactor test cases for 2D convolution transpose to support dynamic kernel types

- Updated `test_conv_transpose_2d` structure to improve parameter handling by reordering constructor arguments.
- Enhanced test case generation to iterate over kernel types, allowing for flexible testing of different configurations.
- Removed hardcoded kernel type instances in favor of a loop for better maintainability and scalability.

* Refactor ggml_compute_forward_conv_transpose_2d to support both F16 and F32 tensor types.

* Refactor conv2d transpose kernel to use a template for kernel type, enhancing flexibility for different data types.
Update test cases to include both F16 and F32 tensor types for comprehensive coverage.

* Update ggml/src/ggml-cuda/conv2d-transpose.cu

Co-authored-by: Aman Gupta <redacted>
* Update ggml/src/ggml-cpu/ggml-cpu.c

Co-authored-by: Aman Gupta <redacted>
* Refactor conv2d transpose implementation by removing the conv2d_transpose_params struct and dispatching with direct kernel launch.

* Enhance cpu conv2d transpose implementation by introducing a templated kernel type for improved flexibility with F16 and F32 data types.

---------

Co-authored-by: Aman Gupta <redacted>
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/conv2d-transpose.cu
ggml/src/ggml-cuda/conv2d-transpose.cuh
tests/test-backend-ops.cpp

index 8b323bd9b06829c8a29e174b7ef1cc2e656f372a..df17cc553006472a3c0f9c4d728e2c60a9f2c2bf 100644 (file)
@@ -2871,8 +2871,12 @@ struct ggml_cplan ggml_graph_plan(
                         const int64_t ne11 = node->src[1]->ne[1]; // H
                         const int64_t ne12 = node->src[1]->ne[2]; // Channels In
 
-                        cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
-                        cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
+                        GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32);
+                        GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32);
+
+                        cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03;
+                        cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12;
+
                     } break;
                 case GGML_OP_TOP_K:
                     {
index 3f85e531daa6c2fec4d9460c222abf53b2cd71ac..d950972c83e9e2df33a9a8f57b959b4813755c3a 100644 (file)
@@ -6923,16 +6923,15 @@ void ggml_compute_forward_conv_3d(
     ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
 }
 
-// ggml_compute_forward_conv_transpose_2d
-
-void ggml_compute_forward_conv_transpose_2d(
-        const ggml_compute_params * params,
-              ggml_tensor * dst) {
+template <typename kernel_t>
+static void ggml_compute_forward_conv_transpose_2d_impl(
+    const ggml_compute_params * params,
+          ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -6943,7 +6942,7 @@ void ggml_compute_forward_conv_transpose_2d(
 
     const int nk = ne00*ne01*ne02*ne03;
 
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (ith == 0) {
@@ -6951,12 +6950,12 @@ void ggml_compute_forward_conv_transpose_2d(
 
         // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
         {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+            kernel_t * const wdata = (kernel_t *) params->wdata + 0;
 
             for (int64_t i03 = 0; i03 < ne03; i03++) {
                 for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
-                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+                    const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+                    kernel_t * dst_data = wdata + i02*ne01*ne00*ne03;
                     for (int64_t i01 = 0; i01 < ne01; i01++) {
                         for (int64_t i00 = 0; i00 < ne00; i00++) {
                             dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
@@ -6968,13 +6967,17 @@ void ggml_compute_forward_conv_transpose_2d(
 
         // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
         {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            kernel_t * const wdata = (kernel_t *) params->wdata + nk;
             for (int i12 = 0; i12 < ne12; i12++) {
                 for (int i11 = 0; i11 < ne11; i11++) {
                     const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
-                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+                    kernel_t * dst_data = wdata + i11*ne10*ne12;
                     for (int i10 = 0; i10 < ne10; i10++) {
-                        dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
+                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
+                            dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
+                        } else {
+                            dst_data[i10*ne12 + i12] = src[i10];
+                        }
                     }
                 }
             }
@@ -6996,21 +6999,27 @@ void ggml_compute_forward_conv_transpose_2d(
     const int ip0 = dp*ith;
     const int ip1 = MIN(ip0 + dp, np);
 
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + nk;
+    kernel_t * const wdata = (kernel_t *) params->wdata + 0;
+    kernel_t * const wdata_src = wdata + nk;
 
     for (int i2 = ip0; i2 < ip1; i2++) { // Cout
         float * dst_data = (float *)((char *) dst->data + i2*nb2);
-        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+        kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
         for (int i11 = 0; i11 < ne11; i11++) {
             for (int i10 = 0; i10 < ne10; i10++) {
                 const int i1n = i11*ne10*ne12 + i10*ne12;
                 for (int i01 = 0; i01 < ne01; i01++) {
                     for (int i00 = 0; i00 < ne00; i00++) {
                         float v = 0;
-                        ggml_vec_dot_f16(ne03, &v, 0,
-                                wdata_src + i1n, 0,
-                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) {
+                            ggml_vec_dot_f16(ne03, &v, 0,
+                                    wdata_src + i1n, 0,
+                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        } else {
+                            ggml_vec_dot_f32(ne03, &v, 0,
+                                    wdata_src + i1n, 0,
+                                    wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        }
                         dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
                     }
                 }
@@ -7019,6 +7028,28 @@ void ggml_compute_forward_conv_transpose_2d(
     }
 }
 
+void ggml_compute_forward_conv_transpose_2d(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_conv_2d_dw
 
 struct ggml_conv_2d_dw_params {
index 03224e404d32de6b074e1ddef601bbd7314a9e90..6cbd6f879e6f5e4a37160580a1a2d94fc4b9c060 100644 (file)
@@ -1,12 +1,20 @@
-#include <algorithm>
-
 #include "conv2d-transpose.cuh"
-#include "ggml.h"
-
-__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
-                                        float * __restrict__ output, const int in_w, const int in_h, const int out_w,
-                                        const int out_h, const int kernel_w, const int kernel_h, const int stride,
-                                        const int c_in, const int c_out, const int batches) {
+#include "convert.cuh"
+
+template <typename kernel_t>
+static __global__ void conv2d_transpose_kernel(const float * __restrict__ input,
+                                               const kernel_t * __restrict__ kernel,
+                                               float * __restrict__ output,
+                                               const int in_w,
+                                               const int in_h,
+                                               const int out_w,
+                                               const int out_h,
+                                               const int kernel_w,
+                                               const int kernel_h,
+                                               const int stride,
+                                               const int c_in,
+                                               const int c_out,
+                                               const int batches) {
     const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
 
     const int total_elements = out_w * out_h * c_out * batches;
@@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const
     for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
         for (int kh = 0; kh < kernel_h; ++kh) {
             int in_y = out_y_idx - kh;
-            if (in_y < 0 || in_y % stride) continue;
+            if (in_y < 0 || in_y % stride) {
+                continue;
+            }
             in_y /= stride;
-            if (in_y >= in_h) continue;
+            if (in_y >= in_h) {
+                continue;
+            }
 
             for (int kw = 0; kw < kernel_w; ++kw) {
                 int in_x = out_x_idx - kw;
-                if (in_x < 0 || in_x % stride) continue;
+                if (in_x < 0 || in_x % stride) {
+                    continue;
+                }
                 in_x /= stride;
-                if (in_x >= in_w) continue;
+                if (in_x >= in_w) {
+                    continue;
+                }
 
                 const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
                 const int kernel_idx =
                     (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
 
-                float input_val = input[input_idx];
-                half  kern_val  = kernel[kernel_idx];
+                float    input_val = input[input_idx];
+                kernel_t kern_val  = kernel[kernel_idx];
 
-                accumulator += input_val * (float) kern_val;
+                accumulator += input_val * ggml_cuda_cast<float>(kern_val);
             }
         }
     }
@@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
     const ggml_tensor * kernel = dst->src[0];
     const ggml_tensor * input  = dst->src[1];
 
-    GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
 
     const float * input_data  = (const float *) input->data;
     float *       output_data = (float *) dst->data;
-    const half * kernel_data = (const half *) kernel->data;
+    const void *  kernel_data = kernel->data;
 
     const int input_w      = input->ne[0];
     const int input_h      = input->ne[1];
@@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor
     GGML_ASSERT(ggml_is_contiguous(kernel));
     GGML_ASSERT(ggml_is_contiguous(dst));
 
-    const int total  = (output_w * output_h * channels_out * batches);
+    const int total  = output_w * output_h * channels_out * batches;
     const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
 
-    conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
-        input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
-        channels_in, channels_out, batches);
+    if (kernel->type == GGML_TYPE_F16) {
+        conv2d_transpose_kernel<half><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+            input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
+            kernel_h, stride, channels_in, channels_out, batches);
+
+    } else {
+        conv2d_transpose_kernel<float><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+            input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w,
+            kernel_h, stride, channels_in, channels_out, batches);
+    }
 }
index c9430b2485021e913403f820983debb17f558e29..72889c5f0fa89ddc1771393cd287c3db8df91e76 100644 (file)
@@ -1,4 +1,5 @@
 #include "common.cuh"
 
 #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
+
 void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 82333a3a6add66492ea66203c12484d20636b80b..ac284db2b8a4c6d4feb578797d959904ebb0180e 100644 (file)
@@ -4823,28 +4823,33 @@ struct test_conv_transpose_1d : public test_case {
 
 // GGML_OP_CONV_TRANSPOSE_2D
 struct test_conv_transpose_2d : public test_case {
+    // Dimensions
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
     const int stride;
+    // Types
+    const ggml_type kernel_type;
 
     std::string vars() override {
-        return VARS_TO_STR3(ne_input, ne_kernel, stride);
+        return VARS_TO_STR4(kernel_type, ne_input, ne_kernel, stride);
     }
 
     double max_nmse_err() override {
         return 5e-4; // The default 1e-7 is too small for Vulkan.
     }
 
-    test_conv_transpose_2d(std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
-                           std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
-                           int stride = 1)
-        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride){}
+    test_conv_transpose_2d(
+        std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
+        std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
+        int stride = 1,
+        ggml_type kernel_type = GGML_TYPE_F16
+    ) : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), kernel_type(kernel_type) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
         ggml_set_name(input, "input");
 
-        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne_kernel.data());
+        ggml_tensor * kernel = ggml_new_tensor(ctx, kernel_type, 4, ne_kernel.data());
         ggml_set_name(kernel, "kernel");
 
         ggml_tensor * out = ggml_conv_transpose_2d_p0(ctx, kernel, input, stride);
@@ -7704,9 +7709,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
-    test_cases.emplace_back(new test_conv_transpose_2d({129, 63, 35, 1}, {3, 3, 48, 35}, 1));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({129, 63, 35, 1}, {3, 3, 48, 35}, 1, kernel_type));
+    }
 
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4,  500, 1, 1}));
     test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
@@ -8892,9 +8899,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
 
-    test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1));
-    test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2));
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({16, 16, 16, 1}, {3, 3, 8, 16}, 1, kernel_type));
+        test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
+    }
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));