From: Oliver Simons Date: Wed, 3 Sep 2025 17:59:16 +0000 (+0200) Subject: CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E... X-Git-Tag: upstream/0.0.6527~157 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=661ae31c9c68201577e70278285b349a5a662caf;p=pkg%2Fggml%2Fsources%2Fllama.cpp CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E (#15715) * Add fastdiv, use it in modulo and use modulo in rms_norm_f32 Fastdiv is much faster way to do integer division, which was identified as bottleneck in rms_norm_f32 * Support more `block_size` values in `rms_norm_f32` This makes us more flexible in selecting the optimal threads w.r.t paralellizing across a col vs. launch-overheads of threads and mio throttles * Update ggml/src/ggml-cuda/common.cuh Co-authored-by: Johannes Gäßler * Replace modulo with fastmodulo in `rms_norm_f32` * Use `BinPackArguments=true` for formating function calls Will file a separate PR to adjust .clang-format file * Update ggml/src/ggml-cuda/common.cuh Co-authored-by: Johannes Gäßler * Use uint3 for both `fastdiv` and `fastmodulo` The compiler seems to reliably optimize away the unused .z component in the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3 * More constrained type declarations Co-authored-by: Johannes Gäßler * Rename fastdiv and fastmodulo variables to shared variable name As suggest by JohannesGaessler, this increases clarity of the intended use * Pack fastdiv/fastmodulo constants into uint2/uint3 objects By packing constants to be used together into a struct, we are less likely to make errors. * Rename function parameter of fastmodulo `modulo_consts` is more fitting/descriptive --------- Co-authored-by: Johannes Gäßler --- diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 85bc9e93..a2dc26ea 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -563,6 +563,38 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static const uint3 init_fastdiv_values(uint32_t d) { + // compute L = ceil(log2(d)); + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + // pack divisor as well to reduce error surface + return make_uint3(mp, L, d); +} + +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in + // fastdiv_values.z is unused and optimized away by the compiler. + // Compute high 32 bits of n * mp + const uint32_t hi = __umulhi(n, fastdiv_values.x); + // add n, apply bit shift + return (hi + n) >> fastdiv_values.y; +} + +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain in (see init_fastdiv_values) + return n - fastdiv(n, fastdiv_values) * fastdiv_values.z; +} + typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); static __device__ __forceinline__ float get_alibi_slope( diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index d5157d95..4f153c57 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } template -static __global__ void rms_norm_f32(const float * x, float * dst, +static __global__ void rms_norm_f32(const float * x, + float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, - const float * mul = nullptr, - const int64_t mul_stride_row = 0, - const int64_t mul_stride_channel = 0, - const int64_t mul_stride_sample = 0, - const int mul_ncols = 0, - const int mul_nrows = 0, - const int mul_nchannels = 0, - const int mul_nsamples = 0, - const float * add = nullptr, - const int64_t add_stride_row = 0, - const int64_t add_stride_channel = 0, - const int64_t add_stride_sample = 0, - const int add_ncols = 0, - const int add_nrows = 0, - const int add_nchannels = 0, - const int add_nsamples = 0) { - + const float * mul = nullptr, + const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, + const int64_t mul_stride_sample = 0, + const uint3 mul_ncols_packed = make_uint3(0, 0, 0), + const uint3 mul_nrows_packed = make_uint3(0, 0, 0), + const uint3 mul_nchannels_packed = make_uint3(0, 0, 0), + const uint3 mul_nsamples_packed = make_uint3(0, 0, 0), + const float * add = nullptr, + const int64_t add_stride_row = 0, + const int64_t add_stride_channel = 0, + const int64_t add_stride_sample = 0, + const uint3 add_ncols_packed = make_uint3(0, 0, 0), + const uint3 add_nrows_packed = make_uint3(0, 0, 0), + const uint3 add_nchannels_packed = make_uint3(0, 0, 0), + const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst, dst += ((sample*nchannels + channel)*nrows + row)*ncols; if constexpr (do_multiply) { - const int mul_row = row % mul_nrows; - const int mul_channel = channel % mul_nchannels; - const int mul_sample = sample % mul_nsamples; - mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; + const uint32_t mul_row = fastmodulo(row, mul_nrows_packed); + const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed); + const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed); + mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row; } if constexpr (do_add) { - const int add_row = row % add_nrows; - const int add_channel = channel % add_nchannels; - const int add_sample = sample % add_nsamples; + const int add_row = fastmodulo(row, add_nrows_packed); + const int add_channel = fastmodulo(channel, add_nchannels_packed); + const int add_sample = fastmodulo(sample, add_nsamples_packed); add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; } @@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst, // sum up partial sums tmp = warp_reduce_sum(tmp); if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); + static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size"); __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } __syncthreads(); - tmp = s_sum[lane_id]; + tmp = 0.0f; + if (lane_id < (block_size / WARP_SIZE)) { + tmp = s_sum[lane_id]; + } tmp = warp_reduce_sum(tmp); } @@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst, for (int col = tid; col < ncols; col += block_size) { if constexpr (do_multiply && do_add) { - const int mul_col = col % mul_ncols; - const int add_col = col % add_ncols; - dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; + const int mul_col = fastmodulo(col, mul_ncols_packed); + const int add_col = fastmodulo(col, add_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; } else if constexpr (do_multiply) { - const int mul_col = col % mul_ncols; - dst[col] = scale * x[col] * mul[mul_col]; + const int mul_col = fastmodulo(col, mul_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col]; } else { dst[col] = scale * x[col]; } @@ -354,77 +357,86 @@ static void rms_norm_f32_cuda( const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } -static void rms_norm_mul_f32_cuda(const float * x, - const float * mul, - const float * add, - float * dst, - const int ncols, - const int nrows, - const int nchannels, - const int nsamples, - const int64_t stride_row, - const int64_t stride_channel, - const int64_t stride_sample, - const int64_t mul_stride_row, - const int64_t mul_stride_channel, - const int64_t mul_stride_sample, - const int mul_ncols, - const int mul_nrows, - const int mul_nchannels, - const int mul_nsamples, - const int64_t add_stride_row, - const int64_t add_stride_channel, - const int64_t add_stride_sample, - const int add_ncols, - const int add_nrows, - const int add_nchannels, - const int add_nsamples, - const float eps, - cudaStream_t stream) { +static void rms_norm_mul_f32_cuda(const float * x, + const float * mul, + const float * add, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const int64_t mul_stride_row, + const int64_t mul_stride_channel, + const int64_t mul_stride_sample, + const uint32_t mul_ncols, + const uint32_t mul_nrows, + const uint32_t mul_nchannels, + const uint32_t mul_nsamples, + const int64_t add_stride_row, + const int64_t add_stride_channel, + const int64_t add_stride_sample, + const uint32_t add_ncols, + const uint32_t add_nrows, + const uint32_t add_nchannels, + const uint32_t add_nsamples, + const float eps, + cudaStream_t stream) { const dim3 blocks_num(nrows, nchannels, nsamples); if (mul == nullptr) { rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); return; } if (add == nullptr) { + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); + rms_norm_f32<1024, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } } else { + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + + const uint3 add_ncols_packed = init_fastdiv_values(add_ncols); + const uint3 add_nrows_packed = init_fastdiv_values(add_nrows); + const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels); + const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, - add, add_stride_row, add_stride_channel, add_stride_sample, - add_ncols, add_nrows, add_nchannels, add_nsamples); + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><<>>(x, dst, - ncols, stride_row, stride_channel, stride_sample, eps, - mul, mul_stride_row, mul_stride_channel, mul_stride_sample, - mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, - add, add_stride_row, add_stride_channel, add_stride_sample, - add_ncols, add_nrows, add_nchannels, add_nsamples); + rms_norm_f32<1024, true, true><<>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); } } }