]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: add ops for WAN video model (cuda && cpu) (#15669)
authorleejet <redacted>
Thu, 4 Sep 2025 08:38:49 +0000 (16:38 +0800)
committerGitHub <redacted>
Thu, 4 Sep 2025 08:38:49 +0000 (10:38 +0200)
* add conv3d support

* add ggml_pad_ext for cpu & cuda backend

* cuda/cpu: add im2col_3d support

* cuda: make im2col a little faster

* fix cuda pad/scale/im2col3d

* make im2col_3d faster

* gguf: support loading tensors which n_dims > GGML_MAX_DIMS

* fix cuda get_rows

* avoid ggml_conv_3d conflict

* correct GGML_OP_COUNT assertion

* avoid build failure

* avoid build failure on MacOS

* cuda: remove unnecessary MIN define

* fix cpu im2col_3d

* adjust the code style

* cuda: use simpler loop in get_rows

* add test_im2col_3d to test-backend-ops

* test-backend-ops.cpp: remove trailing whitespace

* cpu: im2col_3d support non continuous src

Co-authored-by: Jeff Bolz <redacted>
* fix test_im2col_3d

* remove unused variables

* cuda: get_rows: dfloat2 -> float2

* add test_pad_ext to test-backend-ops.cpp

* add gguf_init_from_file_ext impl

* Revert "gguf: support loading tensors which n_dims > GGML_MAX_DIMS"

This reverts commit d8377a0a37f314bd3713fe043b4333ad661610c1.

* Revert "add gguf_init_from_file_ext impl"

This reverts commit d9f1d13208c68ef83b3538201ac7f31614fb1994.

* update ggml_backend_vk_device_supports_op

* fix ggml_backend_vk_device_supports_op

* update other backend supports op for ggml_pad_ext

* metal/opencl/sycl/vulkan: fix GGML_OP_PAD check in supports_op

---------

Co-authored-by: Jeff Bolz <redacted>
17 files changed:
ggml/include/ggml.h
ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/ops.h
ggml/src/ggml-cuda/getrows.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/im2col.cu
ggml/src/ggml-cuda/im2col.cuh
ggml/src/ggml-cuda/pad.cu
ggml/src/ggml-cuda/scale.cu
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml.c
tests/test-backend-ops.cpp

index 7e9c3c8c7a096d05a2a55da717aa74e12cf3c00b..c01b98ac78f5a823665fe4fad824a3effa94acac 100644 (file)
@@ -511,6 +511,7 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_1D,
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
+        GGML_OP_IM2COL_3D,
         GGML_OP_CONV_2D,
         GGML_OP_CONV_3D,
         GGML_OP_CONV_2D_DW,
@@ -1870,6 +1871,41 @@ extern "C" {
             int                   d0,  // dilation dimension 0
             int                   d1); // dilation dimension 1
 
+    GGML_API struct ggml_tensor * ggml_im2col_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int64_t               IC,
+            int                   s0, // stride width
+            int                   s1, // stride height
+            int                   s2, // stride depth
+            int                   p0, // padding width
+            int                   p1, // padding height
+            int                   p2, // padding depth
+            int                   d0, // dilation width
+            int                   d1, // dilation height
+            int                   d2, // dilation depth
+            enum ggml_type        dst_type);
+
+    // a: [OC*IC, KD, KH, KW]
+    // b: [N*IC, ID, IH, IW]
+    // result: [N*OC, OD, OH, OW]
+    GGML_API struct ggml_tensor * ggml_conv_3d(
+                struct ggml_context * ctx,
+                struct ggml_tensor  * a,
+                struct ggml_tensor  * b,
+                int64_t               IC,
+                int                   s0, // stride width
+                int                   s1, // stride height
+                int                   s2, // stride depth
+                int                   p0, // padding width
+                int                   p1, // padding height
+                int                   p2, // padding depth
+                int                   d0, // dilation width
+                int                   d1, // dilation height
+                int                   d2  // dilation depth
+        );
+
     // kernel size is a->ne[0] x a->ne[1]
     // stride is equal to kernel size
     // padding is zero
@@ -1941,7 +1977,7 @@ extern "C" {
             int                   d0,  // dilation dimension 0
             int                   d1); // dilation dimension 1
 
-    GGML_API struct ggml_tensor * ggml_conv_3d(
+    GGML_API struct ggml_tensor * ggml_conv_3d_direct(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]
             struct ggml_tensor  * b,   // input  [W, H, D, C * N]
@@ -2048,6 +2084,19 @@ extern "C" {
             int                  p2,
             int                  p3);
 
+    GGML_API struct ggml_tensor * ggml_pad_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  lp0,
+            int                  rp0,
+            int                  lp1,
+            int                  rp1,
+            int                  lp2,
+            int                  rp2,
+            int                  lp3,
+            int                  rp3
+            );
+
     // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
     GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
             struct ggml_context * ctx,
index 2d81fbd5a185b11704ece8997d956b9670e902f1..ac2e2e1adf3bb8175f58baa51a945956013bbc8a 100755 (executable)
@@ -589,9 +589,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // the position of elements in the array means which dirction to padding,
     // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,
     //                       dim2.front, dim2.behind, dim3.front, dim3.behind]
-    int64_t paddings[] = {
-        0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
-        0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
+    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+
+    int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3};
     aclnn_pad(ctx, acl_src, acl_dst, paddings);
     ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
index 78ec189d4c64671a62351339070fd4a60f5ec1c3..0d35d9333e3f5913179b6df8b97a643ee91aab37 100644 (file)
@@ -1876,6 +1876,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_im2col_back_f32(params, tensor);
             } break;
+        case GGML_OP_IM2COL_3D:
+            {
+                ggml_compute_forward_im2col_3d(params, tensor);
+            } break;
         case GGML_OP_CONV_2D:
             {
                 ggml_compute_forward_conv_2d(params, tensor);
@@ -2255,6 +2259,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
+        case GGML_OP_IM2COL_3D:
         case GGML_OP_CONV_2D:
         case GGML_OP_CONV_3D:
         case GGML_OP_CONV_2D_DW:
index 8c1f7948855ac5c7076f781da9a36fa0a233c207..0bb767e01aed3ba66aec91fcc691a24290f3d08a 100644 (file)
@@ -7027,6 +7027,209 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+
+// ggml_compute_forward_im2col_3d_f16
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f16(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+    const int64_t OH_OW = OH*OW;
+    const int64_t KD_KH_KW = KD*KH*KW;
+    const int64_t KH_KW = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_im2col_3d_f32
+// src0: kernel [OC*IC, KD, KH, KW]
+// src1: image [N*IC, ID, IH, IW]
+// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
+static void ggml_compute_forward_im2col_3d_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    GGML_UNUSED(OC);
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    const int64_t OH_OW = OH*OW;
+    const int64_t KD_KH_KW = KD*KH*KW;
+    const int64_t KH_KW = KH*KW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iod = 0; iod < OD; iod++) {
+                for (int64_t ioh = 0; ioh < OH; ioh++) {
+                    for (int64_t iow = 0; iow < OW; iow++) {
+                        for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                            // micro kernel
+                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
+                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
+
+                            for (int64_t ikd = 0; ikd < KD; ikd++) {
+                                for (int64_t ikh = 0; ikh < KH; ikh++) {
+                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
+                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
+
+                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
+                                        } else {
+                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
+                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+
+void ggml_compute_forward_im2col_3d(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_im2col_3d_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_im2col_3d_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
                               void * a, void * b, float * c) {
     const ggml_type_traits * traits = ggml_get_type_traits(type);
@@ -8014,6 +8217,15 @@ static void ggml_compute_forward_pad_f32(
     GGML_TENSOR_UNARY_OP_LOCALS
 
     float * dst_ptr = (float *) dst->data;
+    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+
 
     // TODO: optimize
 
@@ -8022,10 +8234,12 @@ static void ggml_compute_forward_pad_f32(
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
                     const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-
-                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-
-                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
+                         && (i1 >= lp1 && i1 < ne1 - rp1) \
+                         && (i2 >= lp2 && i2 < ne2 - rp2) \
+                         && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                         dst_ptr[dst_idx] = *src_ptr;
                     } else {
                         dst_ptr[dst_idx] = 0;
index d0ea83843b544ae157e6951dc8288ca741cd949b..9824a03b458336b9f40806a39fd402007ec77e0d 100644 (file)
@@ -69,6 +69,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
 void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index 3ec0e957ab5542bf866226e447d4d9eeaf275f95..83d02474f5d4851cb099ff6cdfa045e4cb4cd54d 100644 (file)
@@ -2,6 +2,8 @@
 #include "dequantize.cuh"
 #include "convert.cuh"
 
+#define MAX_GRIDDIM_Y 65535
+
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
         const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
@@ -11,32 +13,29 @@ static __global__ void k_get_rows(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
-    const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
-    const int i10 =  blockIdx.x;
-    const int i11 =  blockIdx.z / ne12;
-    const int i12 =  blockIdx.z % ne12;
+    for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
+        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+        const int i10 =  blockIdx.x;
+        const int i11 =  blockIdx.z / ne12;
+        const int i12 =  blockIdx.z % ne12;
 
-    if (i00 >= ne00) {
-        return;
-    }
+        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+        const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
+        const int ib   =  i00/qk;      // block index
+        const int iqs  = (i00%qk)/qr;  // quant index
+        const int iybs = i00 - i00%qk; // dst block start index
+        const int y_offset = qr == 1 ? 1 : qk/2;
 
-    const int ib   =  i00/qk;      // block index
-    const int iqs  = (i00%qk)/qr;  // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+        // dequantize
+        float2 v;
+        dequantize_kernel(src0_row, ib, iqs, v);
 
-    // dequantize
-    float2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
-
-    dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
-    dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+        dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
+        dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+    }
 }
 
 template<typename src0_t, typename dst_t>
@@ -48,22 +47,23 @@ static __global__ void k_get_rows_float(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
-    const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
-    const int i10 = blockIdx.x;
-    const int i11 = blockIdx.z / ne12;
-    const int i12 = blockIdx.z % ne12;
+    for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
+        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+        const int i10 = blockIdx.x;
+        const int i11 = blockIdx.z / ne12;
+        const int i12 = blockIdx.z % ne12;
 
-    if (i00 >= ne00) {
-        return;
-    }
+        if (i00 >= ne00) {
+            return;
+        }
 
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
+        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+        const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+        dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+    }
 }
 
 template<typename grad_t, typename dst_t>
@@ -98,7 +98,7 @@ static void get_rows_cuda_q(
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(ne10, block_num_y, ne11*ne12);
+    const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
@@ -131,7 +131,7 @@ static void get_rows_cuda_float(
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
-    const dim3 block_nums(ne10, block_num_y, ne11*ne12);
+    const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
index e06f95f0819ed71768a39d3c26fe2721fc10e569..0c01eb6fa8359148f1b842950f60834273001952 100644 (file)
@@ -2452,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_IM2COL_3D:
+            ggml_cuda_op_im2col_3d(ctx, dst);
+            break;
         case GGML_OP_CONV_2D:
             ggml_cuda_op_conv2d(ctx, dst);
             break;
@@ -3559,6 +3562,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
         }
         case GGML_OP_IM2COL:
+        case GGML_OP_IM2COL_3D:
         case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
index 16bb9bec97d25afb1b659ea1249fb16d40d34e07..7737d6a5d5230ad112c825441ec59a31e6c63a3e 100644 (file)
@@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     }
 }
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static  __global__ void im2col_3d_kernel(
+        const float * src, T * dst,
+        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+        int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+        int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
+        int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
+        int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
+        int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= IC_KD_KH_KW) {
+        return;
+    }
+
+    const int64_t iic = i / KD_KH_KW;
+    const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
+    const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
+    const int64_t ikw = i % KW;
+
+    const int64_t  iow = blockIdx.y;
+    for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
+        const int64_t in  = iz / OD_OH;
+        const int64_t iod = (iz - in*OD_OH) / OH;
+        const int64_t ioh = iz % OH;
+
+        const int64_t iiw = iow * s0 + ikw * d0 - p0;
+        const int64_t iih = ioh * s1 + ikh * d1 - p1;
+        const int64_t iid = iod * s2 + ikd * d2 - p2;
+
+        const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
+
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+            dst[offset_dst] = 0.0f;
+        } else {
+            const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
+            dst[offset_dst] = src[offset_src];
+        }
+    }
+}
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static void im2col_3d_cuda(const float * src, T* dst,
+    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+    const int64_t OH_OW = OH*OW;
+    const int64_t KD_KH_KW = KD*KH*KW;
+    const int64_t ID_IH_IW = ID*IH*IW;
+    const int64_t KH_KW = KH*KW;
+    const int64_t IH_IW = IH*IW;
+    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+    const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
+    const int64_t N_OD_OH = N*OD*OH;
+    const int64_t OD_OH = OD*OH;
+    const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
+    const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
+    const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
+    const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
+    const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
+    im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                                                                                           OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
+                                                                                           IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
+                                                                                           OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
+                                                                                           s0, s1, s2, p0, p1, p2, d0, d1, d2);
+}
+
+static void im2col_3d_cuda_f16(const float * src, half * dst,
+    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+static void im2col_3d_cuda_f32(const float * src, float * dst,
+    int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+    int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+    int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+    const int64_t N  = ne13 / IC;
+    const int64_t ID = ne12;
+    const int64_t IH = ne11;
+    const int64_t IW = ne10;
+
+    const int64_t OC = ne03 / IC;
+    const int64_t KD = ne02;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+
+    const int64_t OD = ne3 / N;
+    const int64_t OH = ne2;
+    const int64_t OW = ne1;
+
+    if(dst->type == GGML_TYPE_F16) {
+        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    } else {
+        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    }
+}
index 1ce8fae4d9a3d4745a0e0450a9dae89e974757ae..2da1223d6345b413a1431878a9f5cb7b21cab1bb 100644 (file)
@@ -3,3 +3,4 @@
 #define CUDA_IM2COL_BLOCK_SIZE 256
 
 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 77432b04689be128b5be18e9f0d74c47b7cd8fd1..29aef33c1a4b85c48d0a257daffcb8562eba4c88 100644 (file)
@@ -1,36 +1,50 @@
 #include "pad.cuh"
 
-static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
-    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
-    // blockIdx.y: idx of ne1
-    // blockIDx.x: idx of ne0 / BLOCK_SIZE
-    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
-    if (nidx >= ne0) {
+static __global__ void pad_f32(const float * src, float * dst,
+                               const int lp0, const int rp0, const int lp1, const int rp1,
+                               const int lp2, const int rp2, const int lp3, const int rp3,
+                               const int ne0, const int ne1, const int ne2, const int ne3) {
+    // blockIdx.z: i3*ne2+i2
+    // blockIdx.y: i1
+    // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
+    // gridDim.y:  ne1
+    int i0 = threadIdx.x + blockIdx.x * blockDim.x;
+    int i1 = blockIdx.y;
+    int i2 = blockIdx.z % ne2;
+    int i3 = blockIdx.z / ne2;
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
     // operation
-    int offset_dst =
-        nidx +
-        blockIdx.y * ne0 +
-        blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
-        int offset_src =
-            nidx +
-            blockIdx.y * ne00 +
-            blockIdx.z * ne00 * ne01;
-        dst[offset_dst] = x[offset_src];
+    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
+        (i1 >= lp1 && i1 < ne1 - rp1) &&
+        (i2 >= lp2 && i2 < ne2 - rp2) &&
+        (i3 >= lp3 && i3 < ne3 - rp3)) {
+        const int64_t i00 = i0 - lp0;
+        const int64_t i01 = i1 - lp1;
+        const int64_t i02 = i2 - lp2;
+        const int64_t i03 = i3 - lp3;
+        const int64_t ne02 = ne2 - lp2 - rp2;
+        const int64_t ne01 = ne1 - lp1 - rp1;
+        const int64_t ne00 = ne0 - lp0 - rp0;
+
+        const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
+
+        dst[dst_idx] = src[src_idx];
     } else {
-        dst[offset_dst] = 0.0f;
+        dst[dst_idx] = 0.0f;
     }
 }
 
-static void pad_f32_cuda(const float * x, float * dst,
-    const int ne00, const int ne01, const int ne02, const int ne03,
+static void pad_f32_cuda(const float * src, float * dst,
+    const int lp0, const int rp0, const int lp1, const int rp1,
+    const int lp2, const int rp2, const int lp3, const int rp3,
     const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2*ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
 }
 
 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
+    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
+    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
 
     pad_f32_cuda(src0_d, dst_d,
-        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+                 lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+                 dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 }
index 2ee9e588992f46cce94710d684b7199fed441907..0ddeff6a1755f11a9c8c442a07443339ccec1337 100644 (file)
@@ -1,18 +1,19 @@
 #include "scale.cuh"
 
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+#define MAX_GRIDDIM_X 0x7FFFFFFF
 
-    if (i >= k) {
-        return;
-    }
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
+    int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
+    int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
 
-    dst[i] = scale * x[i] + bias;
+    for (int64_t i = tid; i < nelements; i += stride) {
+        dst[i] = scale * x[i] + bias;
+    }
 }
 
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
+    const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
 }
 
 void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 3d16a1dcd461e10fbedf18490c6245c5ed6ee483..9b4006d987c3b9558139964ac855f6a51ec66642 100644 (file)
@@ -1886,7 +1886,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_POOL_2D:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
+            return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
index fc54c90f7b3bd1349427539a7b0f1769ff5a1838..727163b7fdf950e010b975e4b449c3e6d5af723a 100644 (file)
@@ -2701,7 +2701,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
-                   op->src[0]->ne[3] == 1 && op->ne[3] == 1;
+                   op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
+                   (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D:
index 18ff4e0b0c7cf1792057ce6c8bcce659cd99b425..877fbf7e8626d11f1d5b3fa596be42151aadfcde 100644 (file)
@@ -4398,7 +4398,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_POOL_2D:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_PAD:
+            return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
index 3cd5af1cdc4562f0a3209e4169f0b8b7201fb8a3..cd1c66ba7b47650e085fae3a4786ca93f61260be 100644 (file)
@@ -12076,7 +12076,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
         case GGML_OP_SCALE:
+            return true;
         case GGML_OP_PAD:
+            return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
+                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_ROLL:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
index d76ea58f789e2bc4883ff7317cf6f89fc869f092..f35c337952ec398e3a90a4cf14c27922c9403615 100644 (file)
@@ -974,6 +974,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "IM2COL_3D",
     "CONV_2D",
     "CONV_3D",
     "CONV_2D_DW",
@@ -1018,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "im2col_3d(x)",
     "conv_2d(x)",
     "conv_3d(x)",
     "conv_2d_dw(x)",
@@ -1121,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89");
+static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4361,6 +4363,91 @@ struct ggml_tensor * ggml_conv_2d(
     return result;
 }
 
+// a: [OC*IC, KD, KH, KW]
+// b: [N*IC, ID, IH, IW]
+// result: [N*OD, OH, OW, IC * KD * KH * KW]
+struct ggml_tensor * ggml_im2col_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int64_t               IC,
+        int                   s0, // stride width
+        int                   s1, // stride height
+        int                   s2, // stride depth
+        int                   p0, // padding width
+        int                   p1, // padding height
+        int                   p2, // padding depth
+        int                   d0, // dilation width
+        int                   d1, // dilation height
+        int                   d2, // dilation depth
+        enum ggml_type        dst_type) {
+    const int64_t N = b->ne[3] / IC;
+    const int64_t ID = b->ne[2];
+    const int64_t IH = b->ne[1];
+    const int64_t IW = b->ne[0];
+
+    const int64_t OC = a->ne[3] / IC;
+    UNUSED(OC);
+    const int64_t KD = a->ne[2];
+    const int64_t KH = a->ne[1];
+    const int64_t KW = a->ne[0];
+    const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
+    const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
+    const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
+
+    GGML_ASSERT((OD > 0)  && "b too small compared to a");
+    GGML_ASSERT((OH > 0)  && "b too small compared to a");
+    GGML_ASSERT((OW > 0)  && "b too small compared to a");
+
+
+    const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
+    int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_IM2COL_3D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// a: [OC*IC, KD, KH, KW]
+// b: [N*IC, ID, IH, IW]
+// result: [N*OC, OD, OH, OW]
+struct ggml_tensor * ggml_conv_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int64_t               IC,
+        int                   s0, // stride width
+        int                   s1, // stride height
+        int                   s2, // stride depth
+        int                   p0, // padding width
+        int                   p1, // padding height
+        int                   p2, // padding depth
+        int                   d0, // dilation width
+        int                   d1, // dilation height
+        int                   d2  // dilation depth
+        ) {
+    struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
+
+    int64_t OC = a->ne[3] / IC;
+    int64_t N = b->ne[3] / IC;
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC));                          // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
+
+    int64_t OD = im2col->ne[3] / N;
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
+    result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
+
+    return result;
+}
+
 // ggml_conv_2d_sk_p0
 
 struct ggml_tensor * ggml_conv_2d_sk_p0(
@@ -4482,9 +4569,9 @@ struct ggml_tensor * ggml_conv_2d_direct(
     return result;
 }
 
-// ggml_conv_3d
+// ggml_conv_3d_direct
 
-struct ggml_tensor * ggml_conv_3d(
+struct ggml_tensor * ggml_conv_3d_direct(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -4710,11 +4797,36 @@ struct ggml_tensor * ggml_pad(
         int                   p1,
         int                   p2,
         int                   p3) {
+    return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
+}
+
+struct ggml_tensor * ggml_pad_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  lp0,
+            int                  rp0,
+            int                  lp1,
+            int                  rp1,
+            int                  lp2,
+            int                  rp2,
+            int                  lp3,
+            int                  rp3
+            ) {
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
-            a->ne[0] + p0,
-            a->ne[1] + p1,
-            a->ne[2] + p2,
-            a->ne[3] + p3);
+            a->ne[0] + lp0 + rp0,
+            a->ne[1] + lp1 + rp1,
+            a->ne[2] + lp2 + rp2,
+            a->ne[3] + lp3 + rp3);
+
+    ggml_set_op_params_i32(result, 0, lp0);
+    ggml_set_op_params_i32(result, 1, rp0);
+    ggml_set_op_params_i32(result, 2, lp1);
+    ggml_set_op_params_i32(result, 3, rp1);
+    ggml_set_op_params_i32(result, 4, lp2);
+    ggml_set_op_params_i32(result, 5, rp2);
+    ggml_set_op_params_i32(result, 6, lp3);
+    ggml_set_op_params_i32(result, 7, rp3);
+
 
     result->op     = GGML_OP_PAD;
     result->src[0] = a;
index 3a58621094d17655047ec165ee6da9521eb689a6..89b812f1abb4fd1ae0a66625cc235329cf9b639d 100644 (file)
@@ -297,6 +297,8 @@ static std::string var_to_str(ggml_scale_mode mode) {
 #define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
 #define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
 #define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
+#define VARS_TO_STR14(a, b, c, d, e, f, g, h, i, j, k, l, m, n) VAR_TO_STR(a) + "," + VARS_TO_STR13(b, c, d, e, f, g, h, i, j, k, l, m, n)
+#define VARS_TO_STR15(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) VAR_TO_STR(a) + "," + VARS_TO_STR14(b, c, d, e, f, g, h, i, j, k, l, m, n, o)
 
 #ifdef GGML_USE_SYCL
 static bool inline _isinf(float f) {
@@ -4023,6 +4025,56 @@ struct test_im2col : public test_case {
     }
 };
 
+// GGML_OP_IM2COL_3D
+struct test_im2col_3d : public test_case {
+    const ggml_type type_input;
+    const ggml_type type_kernel;
+    const ggml_type dst_type;
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    // stride
+    const int s0;
+    const int s1;
+    const int s2;
+    // padding
+    const int p0;
+    const int p1;
+    const int p2;
+    // dilation
+    const int d0;
+    const int d1;
+    const int d2;
+
+    const int64_t IC;
+
+    std::string vars() override {
+        return VARS_TO_STR15(type_input, type_kernel, dst_type, ne_input, ne_kernel, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2);
+    }
+
+    test_im2col_3d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
+                std::array<int64_t, 4> ne_input = {10, 10, 10, 9}, // [OC*IC, KD, KH, KW]
+                std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [N*IC, ID, IH, IW]
+                int64_t IC = 3,
+                int s0 = 1, int s1 = 1, int s2 = 1,
+                int p0 = 1, int p1 = 1, int p2 = 1,
+                int d0 = 1, int d1 = 1, int d2 = 1)
+        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(input);
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        ggml_tensor * out = ggml_im2col_3d(ctx, kernel, input, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, dst_type);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // CONV_2D
 struct test_conv_2d : public test_case {
     const std::array<int64_t, 4> ne_input;
@@ -4221,7 +4273,7 @@ struct test_conv_3d : public test_case {
         ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
         ggml_set_name(kernel, "kernel");
 
-        ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
+        ggml_tensor * out = ggml_conv_3d_direct(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
         ggml_set_name(out, "out");
         return out;
     }
@@ -4640,6 +4692,39 @@ struct test_pad : public test_case {
     }
 };
 
+struct test_pad_ext : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const int lp0;
+    const int rp0;
+    const int lp1;
+    const int rp1;
+    const int lp2;
+    const int rp2;
+    const int lp3;
+    const int rp3;
+
+    std::string vars() override {
+        return VARS_TO_STR10(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+    }
+
+    test_pad_ext(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
+            int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
+            int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1)
+        : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_PAD_REFLECT_1D
 struct test_pad_reflect_1d : public test_case {
     const ggml_type type;
@@ -5623,6 +5708,32 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
 
+    // im2col 3D
+    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+    for (int s0 : {1, 3}) {
+        for (int s1 : {1, 3}) {
+            for (int s2 : {1, 3}) {
+                for (int p0 : {0, 3}) {
+                    for (int p1 : {0, 3}) {
+                        for (int p2 : {0, 3}) {
+                            for (int d0 : {1, 3}) {
+                                for (int d1 : {1, 3}) {
+                                    for (int d2 : {1, 3}) {
+                                        test_cases.emplace_back(new test_im2col_3d(
+                                            GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3},
+                                            3, s0, s1, s2, p0, p1, p2, d0, d1, d2));
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
 // Conv_2D test cases
 #ifdef DETAILED_TESTS
     // Probably we do not have enough time to execute these in the pipeline.
@@ -6340,6 +6451,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_roll());
     test_cases.emplace_back(new test_arange());