]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: Optimize `reduce_rows_f32` kernel, leading up to 25x perf improvement on kernel...
authorOliver Simons <redacted>
Wed, 13 Aug 2025 08:04:46 +0000 (10:04 +0200)
committerGitHub <redacted>
Wed, 13 Aug 2025 08:04:46 +0000 (10:04 +0200)
* Factor out `reduce_rows_f32` from common.cuh

This increases iteration cycle speed by not having to recompile
every kernel all the time

* Hide memory-latency by loop unrolling in reduce_rows_f32

* Further optimizations to `reduce_rows_f32`

1. Increase threadblock size to better hide latency of memory requests.
   As a consequence of bigger threadblocks, do 2-step summation, using
   shared memory to communicate results between invocations
2. Use sum_temp array to reduce waits on sum
3. Adjust num_unroll to reflext bigger threadblock
4. Improve default block_dims, increase support for more block_dims

* Add perf tests for `reduce_rows_f32` kernel

* Add heuristic to toggle 128/512 threads based on sm count

Break even point was the minimum of the following multiples.

| GPU Model                     | Nrow SM Count Multiple |
| -----------                   | -----------            |
| RTX 4000 SFF ADA              | 2.0x                   |
| RTX 6000 ADA                  | 2.5x                   |
| RTX PRO 6000 Blackwell Max-Q  | 3.04x                  |
| RTX PRO 4500 Blackwell | 3.15x                  |

* Ensure perf gains also for small ncols and large nrows

Alternative to this, one could have also made the number of unrollings
template-able, but that would require compiling the kernel multiple
times, increasing binary size unnecessarily

* Modify perf and unit-tests

* Apply auto-formatting by clang

* Fix CI build failure

See https://github.com/ggml-org/llama.cpp/actions/runs/16798370266/job/47573716079?pr=15132#step:7:486
Building with VS generator worked though.

* Remove sm_count property from `ggml_backend_cuda_context`

Requested by @JohannesGaessler, and should fix remaining CI issues as a
side-effect

* Add CUB-based implementation for GGML_OP_MEAN

Currently this branch is only executed for nrows==1

* Add heuristics to execute CUB branch only when it brings perf

Heuristics were determined on the following HW:

* RTX 4000 SFF ADA
* RTX 6000 ADA
* RTX PRO 6000 Blackwell Max-Q
* RTX PRO 4500 Blackwell

* Add unit-test for CUB-based mean

Tests should run with CUDA Graphs enabled per default on NVGPUs

* Rename `USE_CUB` to `GGML_CUDA_USE_CUB`

Suggested by @JohannesGaessler

* Unindent Preprocessor directives

See
https://github.com/ggml-org/llama.cpp/pull/15132#discussion_r2269213506

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/mean.cu
ggml/src/ggml-cuda/reduce_rows.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/sum.cu
ggml/src/ggml-cuda/sumrows.cu
tests/test-backend-ops.cpp

index a23da57e3a1dcb9c12d44e6c3cb5b27dbc921ed0..5a2a3478d26502e6c32e83a1f7120d50fa1fdb14 100644 (file)
 #define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
 #define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)
 
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#    define GGML_CUDA_USE_CUB
+#endif  // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
     return false;
@@ -420,26 +424,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
-// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
-template<bool norm>
-static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col != 0) {
-        return;
-    }
-
-    dst[row] = norm ? sum / ncols : sum;
-}
-
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
 #ifdef GGML_USE_HIP
index 4b238a3998ba3a136a9fb62abe872fac1a479f5d..2ad493239b1dbf088220e744c760c23c14968eb5 100644 (file)
@@ -1,4 +1,14 @@
 #include "mean.cuh"
+#include "reduce_rows.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif  // GGML_CUDA_USE_CUB
+
+template <typename T> __global__ void divide_by_count(T * result, size_t count) {
+    *result /= static_cast<T>(count);
+}
 
 void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0   = dst->src[0];
@@ -13,7 +23,45 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+// Special case for reducing vectors
+#ifdef GGML_CUDA_USE_CUB
+    cudaStreamCaptureStatus iscapturing;
+    CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
+    if ((nrows == 1) &&
+            // CUDA_GRAPHS_DISABLED
+            ((ncols > 65536) &&
+             ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+              ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+              ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
+        // CUDA_GRAPHS ENABLED
+        ((ncols > 32768) &&
+         !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+           ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+           ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
+        // Single row - use device-wide reduction
+        size_t           tmp_size = 0;
+        ggml_cuda_pool & pool     = ctx.pool();
+
+        DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+        DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        // Divide by ncols
+        divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
+        return;
+    }
+#endif
+
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh
new file mode 100644 (file)
index 0000000..6bee204
--- /dev/null
@@ -0,0 +1,53 @@
+#include "common.cuh"
+
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template <bool norm>
+static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float     sum        = 0.0f;
+    const int num_unroll = 8;
+    float     temp[num_unroll];
+    float     sum_temp[num_unroll] = { 0.0f };
+    for (int i = col; i < ncols;) {
+        for (int j = 0; j < num_unroll; ++j) {
+            if (i < ncols) {
+                temp[j] = x[row * ncols + i];
+            } else {
+                temp[j] = 0;
+            }
+            i += blockDim.x;
+        }
+        for (int j = 0; j < num_unroll; ++j) {
+            sum_temp[j] += temp[j];
+        }
+    }
+    for (int j = 0; j < num_unroll; ++j) {
+        sum += sum_temp[j];
+    }
+
+    // sum up partial sums
+    sum = warp_reduce_sum(sum);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float s_sum[32];
+        const int        warp_id = threadIdx.x / WARP_SIZE;
+        const int        lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = 0.0f;
+        if (lane_id < (blockDim.x / WARP_SIZE)) {
+            sum = s_sum[lane_id];
+        }
+        sum = warp_reduce_sum(sum);
+    }
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
index eb3d7cdba98a7ae4b36433e14fe28975b4b3931a..c56257b440661f5c38d54a308a82ecdfa728ca76 100644 (file)
@@ -1,19 +1,15 @@
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
-#define USE_CUB
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#include "sum.cuh"
+#include "sumrows.cuh"
 
-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
 #include <cub/cub.cuh>
 using namespace cub;
-#endif // USE_CUB
-
-#include "sumrows.cuh"
-#include "sum.cuh"
+#endif  // GGML_CUDA_USE_CUB
 
 #include <cstdint>
 
 void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
     size_t tmp_size = 0;
     DeviceReduce::Sum(nullptr,       tmp_size, x, dst, ne, stream);
     ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
@@ -23,7 +19,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
     // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
     sum_rows_f32_cuda(x, dst, ne, 1, stream);
     GGML_UNUSED(pool);
-#endif // USE_CUB
+#endif // GGML_CUDA_USE_CUB
 }
 
 void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 2eee08fa073754e994960d4779e825478beb8efb..4025771aadb9db8aed0c5a86973619f5b02ca9b7 100644 (file)
@@ -1,9 +1,17 @@
+#include "reduce_rows.cuh"
 #include "sumrows.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const int  id  = ggml_cuda_get_device();
+    const int  nsm = ggml_cuda_info().devices[id].nsm;
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -19,8 +27,17 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
 
-    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        // Increase num threads to 512 for small nrows to better hide the latency
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        // Enough active SMs to hide latency, use smaller blocks to allow better scheduling
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
index d29779cd12b229eddef9504b495c0e41028b4653..63e03978e4292964a4aeccce381dbc458c8dba85 100644 (file)
@@ -5998,6 +5998,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_mean());
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
@@ -6179,6 +6188,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
     }
 
+    std::vector<std::array<int64_t, 4>> reduce_rows_cases = {
+        { 8192, 1,    1, 1 },
+        { 8192, 8192, 1, 1 },
+        { 128,  8192, 1, 1 },
+    };
+
+    for (auto it: reduce_rows_cases){
+        test_cases.emplace_back(new test_mean(GGML_TYPE_F32, it));
+        test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, it));
+        test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
+    }
+
     return test_cases;
 }