From: Bowen Han
Date: Thu, 18 Sep 2025 18:26:03 +0000 (-0700)
Subject: CUDA: Optimize PAD_REFLECT_1D (#15957)
X-Git-Tag: upstream/0.0.6527~14
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=38dbdf4c057515ccea9bec0ca2518f86d5e4d28e;p=pkg%2Fggml%2Fsources%2Fllama.cpp

CUDA: Optimize PAD_REFLECT_1D (#15957)

* CUDA: Optimize PAD_REFLECT_1D

feat: add more test cases for PAD_REFLECT_1D

* use fast_div to improve performance

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler

* optimize

* use a concise expression to further speed up the CUDA kernel

---------

Co-authored-by: Johannes Gäßler
---

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 045c6d30..3b134917 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -652,6 +652,14 @@ static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fa
     return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
 }
 
+// Calculate both division and modulo at once, returns <n/divisor, n%divisor>
+static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
+    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+    const uint32_t div_val = fastdiv(n, fastdiv_values);
+    const uint32_t mod_val = n - div_val * fastdiv_values.z;
+    return make_uint2(div_val, mod_val);
+}
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
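
For reference, fast_div_modulo only packages the two values that fastdiv and fastmodulo already compute: the quotient and the remainder n - quotient * divisor, with the divisor carried in fastdiv_values.z. The sketch below is a minimal host-side illustration of that contract and is not something this patch adds; ref_div_modulo and the divisor 384 are made-up names and values, and the real device code replaces the division with fastdiv's precomputed multiply-and-shift.

    // Host-side reference for the contract of fast_div_modulo (illustrative only):
    // given the divisor d that is packed in fastdiv_values.z, it returns
    // (n / d, n % d) in (.x, .y).
    #include <cassert>
    #include <cstdint>
    #include <utility>

    static std::pair<uint32_t, uint32_t> ref_div_modulo(uint32_t n, uint32_t d) {
        const uint32_t div_val = n / d;            // stands in for fastdiv(n, fastdiv_values)
        const uint32_t mod_val = n - div_val * d;  // same identity used by fast_div_modulo
        return { div_val, mod_val };
    }

    int main() {
        for (uint32_t n = 0; n < 10000; ++n) {
            const auto [div_val, mod_val] = ref_div_modulo(n, 384);  // e.g. d = ne01
            assert(div_val == n / 384 && mod_val == n % 384);
        }
        return 0;
    }
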
diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cu b/ggml/src/ggml-cuda/pad_reflect_1d.cu
index 4ed34aec..0478889d 100644
--- a/ggml/src/ggml-cuda/pad_reflect_1d.cu
+++ b/ggml/src/ggml-cuda/pad_reflect_1d.cu
@@ -1,82 +1,89 @@
 #include "pad_reflect_1d.cuh"
 
-static __global__ void pad_reflect_1d_kernel_f32(
-    const void * __restrict__ src0,
-    void * __restrict__ dst,
-    const int64_t ne0,
-    const int64_t ne00,
-    const int64_t ne01,
-    const int64_t ne02,
-    const int64_t ne03,
-    const int64_t nb00,
-    const int64_t nb01,
-    const int64_t nb02,
-    const int64_t nb03,
-    const int64_t nb0,
-    const int64_t nb1,
-    const int64_t nb2,
-    const int64_t nb3,
-    const int p0,
-    const int p1) {
-
+static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
+    pad_reflect_1d_kernel_f32(
+        const void * __restrict__ src0,
+        void * __restrict__ dst,
+        const int64_t ne0,
+        const int64_t ne00,
+        const uint3 ne01,
+        const int64_t ne02,
+        const int64_t ne03,
+        const int64_t nb00,
+        const int64_t nb01,
+        const int64_t nb02,
+        const int64_t nb03,
+        const int64_t nb0,
+        const int64_t nb1,
+        const int64_t nb2,
+        const int64_t nb3,
+        const int p0,
+        const int p1) {
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
-    const int64_t i1 = blockIdx.x;
 
-    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
+    const uint2   div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
+    const int64_t tile1          = div_mod_packed.y;  // i1
+    const int64_t tile0          = div_mod_packed.x;  // nth i0 tile
+    const int64_t i1             = tile1;
+    const int64_t i0             = threadIdx.x + tile0 * blockDim.x;
+
+    // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
+    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
         return;
     }
 
-    const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
-    char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
-
-    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
-        float value;
+    const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
+    char *       dst_ptr  = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
 
-        if (i0 < p0) {
-            // Left padding - reflect
-            value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
-        } else if (i0 < ne0 - p1) {
-            // Middle - copy
-            value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
-        } else {
-            // Right padding - reflect
-            int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
-            value = *(const float *)(src0_ptr + src_idx * nb00);
-        }
+    const int64_t rel_i0 = i0 - p0;  // relative i0 in src0
+    int64_t       src_idx;
 
-        *(float *)(dst_ptr + i0 * nb0) = value;
+    if (rel_i0 < 0) {
+        // Left padding - reflect
+        src_idx = -rel_i0;
+    } else if (rel_i0 < ne00) {
+        // Middle - copy
+        src_idx = rel_i0;
+    } else {
+        // Right padding - reflect
+        src_idx = 2 * ne00 - 2 - rel_i0;
     }
+    const float value = *(const float *) (src0_ptr + src_idx * nb00);
+    *(float *) (dst_ptr + i0 * nb0) = value;
 }
 
 void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    cudaStream_t stream = ctx.stream();
+    const ggml_tensor * src0   = dst->src[0];
+    cudaStream_t        stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     const int32_t * opts = (const int32_t *) dst->op_params;
-    const int p0 = opts[0];
-    const int p1 = opts[1];
+    const int       p0   = opts[0];
+    const int       p1   = opts[1];
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    const int64_t ne00        = src0->ne[0];
+    const int64_t ne01        = src0->ne[1];
+    const uint3   ne01_packed = init_fastdiv_values(ne01);
+    const int64_t ne02        = src0->ne[2];
+    const int64_t ne03        = src0->ne[3];
 
     const int64_t ne0 = dst->ne[0];
 
+    // sanity: padded length matches
    GGML_ASSERT(ne0 == ne00 + p0 + p1);
 
-    const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
-    const dim3 grid_dims(ne01, ne02, ne03);
+    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE;  // threads per block (x)
+    const int64_t     tiles0 = (ne0 + bx - 1) / bx;             // number of tiles along i0
+    // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
+    // grid.y covers i2: [ne02]
+    // grid.z covers i3: [ne03]
+    const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
+    const dim3 block_dims((unsigned) bx, 1, 1);
 
     pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
-        src0->data, dst->data,
-        ne0, ne00, ne01, ne02, ne03,
-        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
-        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
-        p0, p1
-    );
+        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
 }
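
As a quick check of the simplified reflect indexing, the host-side sketch below (illustrative only, not part of the patch) mirrors the three branches of the new kernel: rel_i0 < 0 reflects on the left, rel_i0 < ne00 copies, and anything else reflects on the right via 2 * ne00 - 2 - rel_i0. The sizes ne00 = 5, p0 = p1 = 2 are example values.

    // Host-side reference for the reflect index math used by the new kernel.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00 = 5, p0 = 2, p1 = 2;  // example sizes
        const int64_t ne0  = ne00 + p0 + p1;     // padded length, matches the GGML_ASSERT

        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            const int64_t rel_i0 = i0 - p0;      // relative i0 in src0
            int64_t       src_idx;
            if (rel_i0 < 0) {
                src_idx = -rel_i0;               // left padding - reflect
            } else if (rel_i0 < ne00) {
                src_idx = rel_i0;                // middle - copy
            } else {
                src_idx = 2 * ne00 - 2 - rel_i0; // right padding - reflect
            }
            printf("i0=%2lld -> src_idx=%lld\n", (long long) i0, (long long) src_idx);
        }
        // src_idx sequence: 2 1 0 1 2 3 4 3 2, i.e. the reflection of a 5-element row
        return 0;
    }
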
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 893c2af3..01cbf875 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -6507,6 +6507,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_pad_ext());
     test_cases.emplace_back(new test_pad_reflect_1d());
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
     test_cases.emplace_back(new test_roll());
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
@@ -6645,6 +6646,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
 
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
+    test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
+
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
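
The new perf cases can be exercised through the test-backend-ops binary in perf mode. To relate the largest added shape to the new launch geometry, the sketch below reproduces the grid arithmetic from ggml_cuda_op_pad_reflect_1d for {ne00, ne01, ne02, ne03} = {3000, 384, 4, 1}; the padding (32/32) and the block size (256) are assumptions for illustration, since the real values come from the op params and CUDA_PAD_REFLECT_1D_BLOCK_SIZE.

    // Reproduces the launch-geometry arithmetic of the host launcher for one
    // of the new perf shapes (illustrative values, not part of the patch).
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00 = 3000, ne01 = 384, ne02 = 4, ne03 = 1;
        const int64_t p0 = 32, p1 = 32;              // assumed padding
        const int64_t ne0    = ne00 + p0 + p1;       // padded row length
        const int64_t bx     = 256;                  // assumed threads per block (x)
        const int64_t tiles0 = (ne0 + bx - 1) / bx;  // tiles along i0

        // grid.x covers i1 and all i0 tiles, grid.y covers i2, grid.z covers i3
        printf("grid = (%lld, %lld, %lld), block = (%lld, 1, 1)\n",
               (long long) (ne01 * tiles0), (long long) ne02, (long long) ne03, (long long) bx);
        return 0;
    }
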