]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml: add `conv3d` op (llama/15182)
authorrmatif <redacted>
Fri, 22 Aug 2025 13:33:15 +0000 (15:33 +0200)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:01 +0000 (12:54 +0300)
* add conv3d

* bump GGML_OP_COUNT

include/ggml.h
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cpu/ops.h
src/ggml.c
tests/test-backend-ops.cpp

index b8b82e11c86f5275a2ac6b634a416cf72266096e..7e9c3c8c7a096d05a2a55da717aa74e12cf3c00b 100644 (file)
@@ -512,6 +512,7 @@ extern "C" {
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
         GGML_OP_CONV_2D,
+        GGML_OP_CONV_3D,
         GGML_OP_CONV_2D_DW,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
@@ -1940,6 +1941,23 @@ extern "C" {
             int                   d0,  // dilation dimension 0
             int                   d1); // dilation dimension 1
 
+    GGML_API struct ggml_tensor * ggml_conv_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]
+            struct ggml_tensor  * b,   // input  [W, H, D, C * N]
+            int                   s0,  // stride
+            int                   s1,
+            int                   s2,
+            int                   p0,  // padding
+            int                   p1,
+            int                   p2,
+            int                   d0,  // dilation
+            int                   d1,
+            int                   d2,
+            int                   n_channels,
+            int                   n_batch,
+            int                   n_channels_out);
+
     enum ggml_op_pool {
         GGML_OP_POOL_MAX,
         GGML_OP_POOL_AVG,
index f6bea3df34a0b212457db9c4bf7dfa5e9afc3cd5..0d5d3a3440aaf77336c5430160c990a5216fcc18 100644 (file)
@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_2d(params, tensor);
             } break;
+        case GGML_OP_CONV_3D:
+            {
+                ggml_compute_forward_conv_3d(params, tensor);
+            } break;
         case GGML_OP_CONV_2D_DW:
             {
                 ggml_compute_forward_conv_2d_dw(params, tensor);
@@ -2252,6 +2256,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
         case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_3D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -2773,6 +2778,7 @@ struct ggml_cplan ggml_graph_plan(
                         }
                     } break;
                 case GGML_OP_CONV_2D:
+                case GGML_OP_CONV_3D:
                     {
                         cur = GGML_IM2COL_WORK_SIZE;
                     } break;
index b72a2556a5fc97850613d90fd3e1b8545c79e6aa..460367cca09e9f055973051f913b8d3bfb502c62 100644 (file)
@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
     ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void  * knl_data       = kernel->data;
+    float * dst_data       = (float *) dst->data;
+
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
+    const int64_t knl_n_total       = knl_n_per_channel * c;
+    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *)element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start = params->ith * permute_per_thread;
+        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
+                const float value = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
index 82ea79eaa51ccd8cca78a9e0e28e3c8535a9c5fa..d0ea83843b544ae157e6951dc8288ca741cd949b 100644 (file)
@@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
 void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index a4417f1a17ef4123a0e517c6784514d6093765d2..d76ea58f789e2bc4883ff7317cf6f89fc869f092 100644 (file)
@@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "IM2COL",
     "IM2COL_BACK",
     "CONV_2D",
+    "CONV_3D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
@@ -1017,7 +1018,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
+static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "im2col(x)",
     "im2col_back(x)",
     "conv_2d(x)",
+    "conv_3d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
@@ -1119,7 +1121,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
+static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4480,6 +4482,56 @@ struct ggml_tensor * ggml_conv_2d_direct(
     return result;
 }
 
+// ggml_conv_3d
+
+struct ggml_tensor * ggml_conv_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   s2,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   d0,
+        int                   d1,
+        int                   d2,
+        int                   c,
+        int                   n,
+        int                   oc) {
+
+    GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
+    GGML_ASSERT(b->ne[3] == (int64_t) c * n);
+
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
+    ne[3] = (int64_t) oc * n;
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_i32(result, 0,  s0);
+    ggml_set_op_params_i32(result, 1,  s1);
+    ggml_set_op_params_i32(result, 2,  s2);
+    ggml_set_op_params_i32(result, 3,  p0);
+    ggml_set_op_params_i32(result, 4,  p1);
+    ggml_set_op_params_i32(result, 5,  p2);
+    ggml_set_op_params_i32(result, 6,  d0);
+    ggml_set_op_params_i32(result, 7,  d1);
+    ggml_set_op_params_i32(result, 8,  d2);
+    ggml_set_op_params_i32(result, 9,  c);
+    ggml_set_op_params_i32(result, 10, n);
+    ggml_set_op_params_i32(result, 11, oc);
+
+    result->op = GGML_OP_CONV_3D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
index e21e9042781e49f87c41cdaf8836532783cc1a20..a51527ca55c2359eb0dc3a14385a88e4e8e8cbff 100644 (file)
@@ -4091,6 +4091,75 @@ struct test_conv_2d_dw : public test_case {
     }
 };
 
+// GGML_OP_CONV_3D
+struct test_conv_3d : public test_case {
+    // Logical 5D dimensions
+    const int64_t N, IC, ID, IH, IW;
+    const int64_t OC, KD, KH, KW;
+    // Conv params
+    const int s0, s1, s2;
+    const int p0, p1, p2;
+    const int d0, d1, d2;
+    // Types
+    const ggml_type type_kernel;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "CONV_3D";
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR11(N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1) + "," +
+               VARS_TO_STR8(s2, p0, p1, p2, d0, d1, d2, type_kernel);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+        };
+        const int64_t OD = calc_conv_output_size(ID, KD, s2, p2, d2);
+        const int64_t OH = calc_conv_output_size(IH, KH, s1, p1, d1);
+        const int64_t OW = calc_conv_output_size(IW, KW, s0, p0, d0);
+
+        return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);
+    }
+
+    test_conv_3d(
+        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
+        int64_t OC, int64_t KD, int64_t KH, int64_t KW,
+        int s0, int s1, int s2,
+        int p0, int p1, int p2,
+        int d0, int d1, int d2,
+        ggml_type type_kernel
+    ) : N(N), IC(IC), ID(ID), IH(IH), IW(IW),
+        OC(OC), KD(KD), KH(KH), KW(KW),
+        s0(s0), s1(s1), s2(s2),
+        p0(p0), p1(p1), p2(p2),
+        d0(d0), d1(d1), d2(d2),
+        type_kernel(type_kernel) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // GGML input tensor is packed as [W, H, D, C*N]
+        const int64_t ne_input[] = {IW, IH, ID, IC * N};
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);
+        ggml_set_name(input, "input");
+
+        // GGML kernel tensor is packed as [KW, KH, KD, IC*OC]
+        const int64_t ne_kernel[] = {KW, KH, KD, IC * OC};
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
+        ggml_set_name(kernel, "kernel");
+
+        ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
@@ -5528,6 +5597,61 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
 
+    // CONV_3D
+    auto calc_conv_output_size_3d = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+    };
+
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (int N : {1, 2}) {
+            for (int IC : {1, 3}) {
+                for (int OC : {1, 4}) {
+                    for (int s0 : {1, 2}) {
+                        for (int p1 : {0, 1}) {
+                            for (int d2 : {1, 2}) {
+                                int64_t IW = 20, IH = 22, ID = 18;
+                                int64_t KW = 3,  KH = 3,  KD = 3;
+                                int s1 = s0, s2 = s0;
+                                int p0 = p1, p2 = p1;
+                                int d0 = d2, d1 = d2;
+
+                                if (calc_conv_output_size_3d(IW, KW, s0, p0, d0) <= 0 ||
+                                    calc_conv_output_size_3d(IH, KH, s1, p1, d1) <= 0 ||
+                                    calc_conv_output_size_3d(ID, KD, s2, p2, d2) <= 0) {
+                                    continue;
+                                }
+                                test_cases.emplace_back(new test_conv_3d(
+                                    N, IC, ID, IH, IW,
+                                    OC, KD, KH, KW,
+                                    s0, s1, s2, p0, p1, p2, d0, d1, d2,
+                                    kernel_type));
+
+                                // Asymmetric kernel and params
+                                int64_t asym_KW = 5, asym_KH = 1, asym_KD = 3;
+                                int asym_s0 = 2, asym_s1 = 1, asym_s2 = 1;
+                                int asym_p0 = 2, asym_p1 = 0, asym_p2 = 1;
+                                int asym_d0 = 1, asym_d1 = 1, asym_d2 = 2;
+
+                                if (calc_conv_output_size_3d(IW, asym_KW, asym_s0, asym_p0, asym_d0) <= 0 ||
+                                    calc_conv_output_size_3d(IH, asym_KH, asym_s1, asym_p1, asym_d1) <= 0 ||
+                                    calc_conv_output_size_3d(ID, asym_KD, asym_s2, asym_p2, asym_d2) <= 0) {
+                                    continue;
+                                }
+                                test_cases.emplace_back(new test_conv_3d(
+                                    N, IC, ID, IH, IW,
+                                    OC, asym_KD, asym_KH, asym_KW,
+                                    asym_s0, asym_s1, asym_s2, asym_p0, asym_p1, asym_p2, asym_d0, asym_d1, asym_d2,
+                                    kernel_type));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        // Case with kernel size 1
+        test_cases.emplace_back(new test_conv_3d(1, 4, 8, 8, 8, 8, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, kernel_type));
+    }
+
     for(uint32_t Cout : {1, 9}){
         for(uint32_t Cin : {1, 7}){
             for(uint32_t K : {1, 3, 1337}){