]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E...
authorOliver Simons <redacted>
Wed, 3 Sep 2025 17:59:16 +0000 (19:59 +0200)
committerGitHub <redacted>
Wed, 3 Sep 2025 17:59:16 +0000 (19:59 +0200)
* Add fastdiv, use it in modulo and use modulo in rms_norm_f32

Fastdiv is much faster way to do integer division, which was identified
as bottleneck in rms_norm_f32

* Support more `block_size` values in `rms_norm_f32`

This makes us more flexible in selecting the optimal threads w.r.t
paralellizing across a col vs. launch-overheads of threads and mio
throttles

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Johannes Gäßler <redacted>
* Replace modulo with fastmodulo in `rms_norm_f32`

* Use `BinPackArguments=true` for formating function calls

Will file a separate PR to adjust .clang-format file

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Johannes Gäßler <redacted>
* Use uint3 for both `fastdiv` and `fastmodulo`

The compiler seems to reliably optimize away the unused .z component in
the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3

* More constrained type declarations

Co-authored-by: Johannes Gäßler <redacted>
* Rename fastdiv and fastmodulo variables to shared variable name

As suggest by JohannesGaessler, this increases clarity of the intended
use

* Pack fastdiv/fastmodulo constants into uint2/uint3 objects

By packing constants to be used together into a struct, we are less
likely to make errors.

* Rename function parameter of fastmodulo

`modulo_consts` is more fitting/descriptive

---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/norm.cu

index 85bc9e933bca541dda0cb3dcb866c264b1685632..a2dc26eab7e4c4e59f26a05336fa74fb3f1c8f27 100644 (file)
@@ -563,6 +563,38 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+static const uint3 init_fastdiv_values(uint32_t d) {
+    // compute L = ceil(log2(d));
+    uint32_t L = 0;
+    while (L < 32 && (uint32_t{ 1 } << L) < d) {
+        L++;
+    }
+
+    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
+    // pack divisor as well to reduce error surface
+    return make_uint3(mp, L, d);
+}
+
+static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
+    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
+    // fastdiv_values.z is unused and optimized away by the compiler.
+    // Compute high 32 bits of n * mp
+    const uint32_t hi = __umulhi(n, fastdiv_values.x);
+    // add n, apply bit shift
+    return (hi + n) >> fastdiv_values.y;
+}
+
+static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
+    // expects  fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
+}
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
index d5157d958b717926497ac8fc4eb251b39e55acda..4f153c5718eadc5e2ba4a25d634f1a0c65059139 100644 (file)
@@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }
 
 template <int block_size, bool do_multiply = false, bool do_add = false>
-static __global__ void rms_norm_f32(const float * x, float *       dst,
+static __global__ void rms_norm_f32(const float * x,
+                                    float *       dst,
                                     const int     ncols,
                                     const int64_t stride_row,
                                     const int64_t stride_channel,
                                     const int64_t stride_sample,
                                     const float   eps,
-                                    const float * mul                = nullptr,
-                                    const int64_t mul_stride_row     = 0,
-                                    const int64_t mul_stride_channel = 0,
-                                    const int64_t mul_stride_sample  = 0,
-                                    const int     mul_ncols          = 0,
-                                    const int     mul_nrows          = 0,
-                                    const int     mul_nchannels      = 0,
-                                    const int     mul_nsamples       = 0,
-                                    const float * add                = nullptr,
-                                    const int64_t add_stride_row     = 0,
-                                    const int64_t add_stride_channel = 0,
-                                    const int64_t add_stride_sample  = 0,
-                                    const int     add_ncols          = 0,
-                                    const int     add_nrows          = 0,
-                                    const int     add_nchannels      = 0,
-                                    const int     add_nsamples       = 0) {
-
+                                    const float * mul                  = nullptr,
+                                    const int64_t mul_stride_row       = 0,
+                                    const int64_t mul_stride_channel   = 0,
+                                    const int64_t mul_stride_sample    = 0,
+                                    const uint3   mul_ncols_packed     = make_uint3(0, 0, 0),
+                                    const uint3   mul_nrows_packed     = make_uint3(0, 0, 0),
+                                    const uint3   mul_nchannels_packed = make_uint3(0, 0, 0),
+                                    const uint3   mul_nsamples_packed  = make_uint3(0, 0, 0),
+                                    const float * add                  = nullptr,
+                                    const int64_t add_stride_row       = 0,
+                                    const int64_t add_stride_channel   = 0,
+                                    const int64_t add_stride_sample    = 0,
+                                    const uint3   add_ncols_packed     = make_uint3(0, 0, 0),
+                                    const uint3   add_nrows_packed     = make_uint3(0, 0, 0),
+                                    const uint3   add_nchannels_packed = make_uint3(0, 0, 0),
+                                    const uint3   add_nsamples_packed  = make_uint3(0, 0, 0)) {
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
 
@@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
     dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     if constexpr (do_multiply) {
-        const int mul_row = row % mul_nrows;
-        const int mul_channel = channel % mul_nchannels;
-        const int mul_sample = sample % mul_nsamples;
-        mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);
+        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
+        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);
+        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
     }
 
     if constexpr (do_add) {
-        const int add_row     = row % add_nrows;
-        const int add_channel = channel % add_nchannels;
-        const int add_sample  = sample % add_nsamples;
+        const int add_row     = fastmodulo(row, add_nrows_packed);
+        const int add_channel = fastmodulo(channel, add_nchannels_packed);
+        const int add_sample  = fastmodulo(sample, add_nsamples_packed);
         add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
     }
 
@@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
     // sum up partial sums
     tmp = warp_reduce_sum(tmp);
     if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
+        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
         __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
+        const int        warp_id = tid / WARP_SIZE;
+        const int        lane_id = tid % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
         __syncthreads();
-        tmp = s_sum[lane_id];
+        tmp = 0.0f;
+        if (lane_id < (block_size / WARP_SIZE)) {
+            tmp = s_sum[lane_id];
+        }
         tmp = warp_reduce_sum(tmp);
     }
 
@@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
 
     for (int col = tid; col < ncols; col += block_size) {
         if constexpr (do_multiply && do_add) {
-            const int mul_col = col % mul_ncols;
-            const int add_col = col % add_ncols;
-            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+            const int mul_col = fastmodulo(col, mul_ncols_packed);
+            const int add_col = fastmodulo(col, add_ncols_packed);
+            dst[col]          = scale * x[col] * mul[mul_col] + add[add_col];
         } else if constexpr (do_multiply) {
-            const int mul_col = col % mul_ncols;
-            dst[col] = scale * x[col] * mul[mul_col];
+            const int mul_col = fastmodulo(col, mul_ncols_packed);
+            dst[col]          = scale * x[col] * mul[mul_col];
         } else {
             dst[col] = scale * x[col];
         }
@@ -354,77 +357,86 @@ static void rms_norm_f32_cuda(
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        const dim3 block_dims(256, 1, 1);
+        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
         rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
-static void rms_norm_mul_f32_cuda(const float * x,
-                                  const float * mul,
-                                  const float * add,
-                                  float *       dst,
-                                  const int     ncols,
-                                  const int     nrows,
-                                  const int     nchannels,
-                                  const int     nsamples,
-                                  const int64_t stride_row,
-                                  const int64_t stride_channel,
-                                  const int64_t stride_sample,
-                                  const int64_t mul_stride_row,
-                                  const int64_t mul_stride_channel,
-                                  const int64_t mul_stride_sample,
-                                  const int     mul_ncols,
-                                  const int     mul_nrows,
-                                  const int     mul_nchannels,
-                                  const int     mul_nsamples,
-                                  const int64_t add_stride_row,
-                                  const int64_t add_stride_channel,
-                                  const int64_t add_stride_sample,
-                                  const int     add_ncols,
-                                  const int     add_nrows,
-                                  const int     add_nchannels,
-                                  const int     add_nsamples,
-                                  const float   eps,
-                                  cudaStream_t  stream) {
+static void rms_norm_mul_f32_cuda(const float *  x,
+                                  const float *  mul,
+                                  const float *  add,
+                                  float *        dst,
+                                  const int      ncols,
+                                  const int      nrows,
+                                  const int      nchannels,
+                                  const int      nsamples,
+                                  const int64_t  stride_row,
+                                  const int64_t  stride_channel,
+                                  const int64_t  stride_sample,
+                                  const int64_t  mul_stride_row,
+                                  const int64_t  mul_stride_channel,
+                                  const int64_t  mul_stride_sample,
+                                  const uint32_t mul_ncols,
+                                  const uint32_t mul_nrows,
+                                  const uint32_t mul_nchannels,
+                                  const uint32_t mul_nsamples,
+                                  const int64_t  add_stride_row,
+                                  const int64_t  add_stride_channel,
+                                  const int64_t  add_stride_sample,
+                                  const uint32_t add_ncols,
+                                  const uint32_t add_nrows,
+                                  const uint32_t add_nchannels,
+                                  const uint32_t add_nsamples,
+                                  const float    eps,
+                                  cudaStream_t   stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (mul == nullptr) {
         rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
         return;
     }
     if (add == nullptr) {
+        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
+        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
+        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
         if (ncols < 1024) {
-            const dim3 block_dims(WARP_SIZE, 1, 1);
-            rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+            const dim3 block_dims(256, 1, 1);
+            rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         }
     } else {
+        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
+        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
+        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
+
+        const uint3 add_ncols_packed     = init_fastdiv_values(add_ncols);
+        const uint3 add_nrows_packed     = init_fastdiv_values(add_nrows);
+        const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
+        const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
         if (ncols < 1024) {
-            const dim3 block_dims(WARP_SIZE, 1, 1);
-            rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
-                add, add_stride_row, add_stride_channel, add_stride_sample,
-                add_ncols, add_nrows, add_nchannels, add_nsamples);
+            const dim3 block_dims(256, 1, 1);
+            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                add_nchannels_packed, add_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
-                ncols, stride_row, stride_channel, stride_sample, eps,
-                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
-                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
-                add, add_stride_row, add_stride_channel, add_stride_sample,
-                add_ncols, add_nrows, add_nchannels, add_nsamples);
+            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                add_nchannels_packed, add_nsamples_packed);
         }
     }
 }