]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : extend GGML_OP_PAD to work with non-cont src0 (llama/19429)
authorGeorgi Gerganov <redacted>
Tue, 10 Feb 2026 06:07:16 +0000 (08:07 +0200)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
* cuda : extend GGML_OP_PAD to work with non-cont src0

* tests : add permuted pad

src/ggml-cpu/ops.cpp
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/pad.cu
tests/test-backend-ops.cpp

index ce15b18ce0ee97c4b2f4ddeebd33070bac68052a..ed45350207ecd5cc3332001c12058310a462251f 100644 (file)
@@ -7629,8 +7629,7 @@ static void ggml_compute_forward_pad_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    assert(dst->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
index 9e77c231c85f2a45fc8946edc0753476c8a32e8f..b163468789fc7adb7a74b70dbc5104ec0cbe35e4 100644 (file)
@@ -4834,8 +4834,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_PAD:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_PAD:
+            return true;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_ARANGE:
index 660c192e48a868386861b35829bf62f77c808123..31cd00f77816fe6f5ffdda0428289e2aa1a25f9b 100644 (file)
@@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
     return (coord + size) % size;
 }
 
-static __global__ void pad_f32(const float * src, float * dst,
+static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
                                const int lp0, const int rp0, const int lp1, const int rp1,
                                const int lp2, const int rp2, const int lp3, const int rp3,
                                const int ne0, const int ne1, const int ne2, const int ne3,
@@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst,
             const int64_t i01  = i1 - lp1;
             const int64_t i02  = i2 - lp2;
             const int64_t i03  = i3 - lp3;
-            const int64_t ne02 = ne2 - lp2 - rp2;
-            const int64_t ne01 = ne1 - lp1 - rp1;
-            const int64_t ne00 = ne0 - lp0 - rp0;
 
-            const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+            const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
 
             dst[dst_idx] = src[src_idx];
         } else {
@@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst,
         const int64_t i02 = wrap_around(i2 - lp2, ne02);
         const int64_t i03 = wrap_around(i3 - lp3, ne03);
 
-        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+        const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
 
         dst[dst_idx] = src[src_idx];
     }
 }
 
 
-static void pad_f32_cuda(const float * src, float * dst,
+static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
     const int ne0, const int ne1, const int ne2, const int ne3,
     const bool circular, cudaStream_t stream) {
     int  num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2 * ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,
                                                          lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
                                                          ne0, ne1, ne2, ne3, circular);
 }
@@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float *             dst_d  = (float *) dst->data;
     cudaStream_t        stream = ctx.stream();
 
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     const int32_t lp0      = ((const int32_t *) (dst->op_params))[0];
     const int32_t rp0      = ((const int32_t *) (dst->op_params))[1];
@@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int32_t rp3      = ((const int32_t *) (dst->op_params))[7];
     const int32_t circular = ((const int32_t *) (dst->op_params))[8];
 
-    pad_f32_cuda(src0_d, dst_d,
+    const size_t s00 = nb00 / ggml_type_size(src0->type);
+    const size_t s01 = nb01 / ggml_type_size(src0->type);
+    const size_t s02 = nb02 / ggml_type_size(src0->type);
+    const size_t s03 = nb03 / ggml_type_size(src0->type);
+
+    pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,
                  lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
                  dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                  (bool) circular, stream);
index 6fe1780f3bafae4d7f0915873b4909bfa0af4d02..56dadb9b36febe2724d3764724cf181bf4d4177c 100644 (file)
@@ -5894,33 +5894,36 @@ struct test_pad_ext : public test_case {
     const int rp2;
     const int lp3;
     const int rp3;
-    const bool v;
+    const int tfrm; // 0 - none, 1 - non-cont, 2 - perm
     const bool circular;
 
     std::string vars() override {
-        return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular);
+        return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, tfrm, circular);
     }
 
     test_pad_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {512, 512, 3, 1},
             int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1,
             int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1,
-            bool v = false, bool circular = false)
+            int tfrm = 0, bool circular = false)
         : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3),
-          v(v), circular(circular) {}
+          tfrm(tfrm), circular(circular) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
         ggml_set_name(a, "a");
 
-        if (v) {
+        if (tfrm == 1) {
             a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2, a->nb[1], a->nb[2], a->nb[3], 0);
             ggml_set_name(a, "view of a");
+        } else if (tfrm == 2) {
+            a = ggml_permute(ctx, a, 2, 1, 0, 3);
+            ggml_set_name(a, "permuted a");
         }
 
         ggml_tensor * out = circular
             ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3)
-            : ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
+            : ggml_pad_ext         (ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3);
         ggml_set_name(out, "out");
 
         return out;
@@ -8198,10 +8201,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));
 
-    for (bool v : {false, true}) {
+    for (int tfrm : {0, 1, 2}) {
         for (bool circular : {false, true}) {
-            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular));
-            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular));
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, tfrm, circular));
+            test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, tfrm, circular));
         }
     }