]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: Optimize PAD_REFLECT_1D (llama/15957)
authorBowen Han <redacted>
Thu, 18 Sep 2025 18:26:03 +0000 (11:26 -0700)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* CUDA: Optimize PAD_REFLECT_1D
feat: add more test cases for PAD_REFLECT_1D

* use fast_div to improve performance

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <redacted>
* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <redacted>
* optimize

* use a concise expression to further speedup the cuda kernel

---------

Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/common.cuh
src/ggml-cuda/pad_reflect_1d.cu
tests/test-backend-ops.cpp

index 045c6d3006b2e084a90a7648e080c10db133bf45..3b1349171b263e450e6168a65a5540f8c0e109b3 100644 (file)
@@ -652,6 +652,14 @@ static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fa
     return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
 }
 
+// Calculate both division and modulo at once, returns <n/divisor, n%divisor>
+static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
+    // expects  fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+    const uint32_t div_val = fastdiv(n, fastdiv_values);
+    const uint32_t mod_val = n - div_val * fastdiv_values.z;
+    return make_uint2(div_val, mod_val);
+}
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
index 4ed34aec3d33112c09daf443c9ce30c291c6f39f..0478889da13fea01e14b3fc6341795596f3cc455 100644 (file)
@@ -1,82 +1,89 @@
 #include "pad_reflect_1d.cuh"
 
-static __global__ void pad_reflect_1d_kernel_f32(
-    const void * __restrict__ src0,
-    void * __restrict__ dst,
-    const int64_t ne0,
-    const int64_t ne00,
-    const int64_t ne01,
-    const int64_t ne02,
-    const int64_t ne03,
-    const int64_t nb00,
-    const int64_t nb01,
-    const int64_t nb02,
-    const int64_t nb03,
-    const int64_t nb0,
-    const int64_t nb1,
-    const int64_t nb2,
-    const int64_t nb3,
-    const int p0,
-    const int p1) {
-
+static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
+    pad_reflect_1d_kernel_f32(
+        const void * __restrict__ src0,
+        void * __restrict__       dst,
+        const int64_t             ne0,
+        const int64_t             ne00,
+        const uint3               ne01,
+        const int64_t             ne02,
+        const int64_t             ne03,
+        const int64_t             nb00,
+        const int64_t             nb01,
+        const int64_t             nb02,
+        const int64_t             nb03,
+        const int64_t             nb0,
+        const int64_t             nb1,
+        const int64_t             nb2,
+        const int64_t             nb3,
+        const int                 p0,
+        const int                 p1) {
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
-    const int64_t i1 = blockIdx.x;
 
-    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
+    const uint2   div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
+    const int64_t tile1          = div_mod_packed.y;  // i1
+    const int64_t tile0          = div_mod_packed.x;  // nth i0 tile
+    const int64_t i1             = tile1;
+    const int64_t i0             = threadIdx.x + tile0 * blockDim.x;
+
+    // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
+    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
         return;
     }
 
-    const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
-    char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
-
-    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
-        float value;
+    const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
+    char *       dst_ptr  = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
 
-        if (i0 < p0) {
-            // Left padding - reflect
-            value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
-        } else if (i0 < ne0 - p1) {
-            // Middle - copy
-            value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
-        } else {
-            // Right padding - reflect
-            int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
-            value = *(const float *)(src0_ptr + src_idx * nb00);
-        }
+    const int64_t rel_i0 = i0 - p0;  // relative i0 in src0
+    int64_t src_idx;
 
-        *(float *)(dst_ptr + i0 * nb0) = value;
+    if (rel_i0 < 0) {
+        // Left padding - reflect
+        src_idx = -rel_i0;
+    } else if (rel_i0 < ne00) {
+        // Middle - copy
+        src_idx = rel_i0;
+    } else {
+        // Right padding - reflect
+        src_idx = 2 * ne00 - 2 - rel_i0;
     }
+    const float value               = *(const float *) (src0_ptr + src_idx * nb00);
+    *(float *) (dst_ptr + i0 * nb0) = value;
 }
 
 void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    cudaStream_t stream = ctx.stream();
+    const ggml_tensor * src0   = dst->src[0];
+    cudaStream_t        stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const int32_t * opts = (const int32_t *) dst->op_params;
-    const int p0 = opts[0];
-    const int p1 = opts[1];
+    const int       p0   = opts[0];
+    const int       p1   = opts[1];
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    const int64_t ne00        = src0->ne[0];
+    const int64_t ne01        = src0->ne[1];
+    const uint3   ne01_packed = init_fastdiv_values(ne01);
+    const int64_t ne02        = src0->ne[2];
+    const int64_t ne03        = src0->ne[3];
 
     const int64_t ne0 = dst->ne[0];
 
+    // sanity: padded length matches
     GGML_ASSERT(ne0 == ne00 + p0 + p1);
 
-    const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
-    const dim3 grid_dims(ne01, ne02, ne03);
+    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE;  // threads per block (x)
+    const int64_t     tiles0 = (ne0 + bx - 1) / bx;             // number of tiles along i0
+    // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
+    // grid.y covers i2: [ne02]
+    // grid.z covers i3: [ne03]
+    const dim3        grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
+    const dim3        block_dims((unsigned) bx, 1, 1);
 
     pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
-        src0->data, dst->data,
-        ne0, ne00, ne01, ne02, ne03,
-        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
-        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
-        p0, p1
-    );
+        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
 }
index 893c2af3137a8317343bc0a186a89079af2537e9..01cbf8753303ad0182d707f1556c5bfdcc13e6e1 100644 (file)
@@ -6507,6 +6507,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
     test_cases.emplace_back(new test_roll());
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
@@ -6645,6 +6646,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
+
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8,  1}, {4, 1}, {0, 1, 2, 3}, true));