]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda: optimize cumsum cub path (llama/18362)
authorAman Gupta <redacted>
Thu, 25 Dec 2025 15:55:38 +0000 (23:55 +0800)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
* cuda: optimize cumsum cub path

* remove heavy perf test

src/ggml-cuda/cumsum.cu

index 0f72e33bbacc9a2c33742d28bcdd1fd300d2752b..e82171f9c2ab2d0a0b6c5a0561a586c3952c5534 100644 (file)
@@ -5,7 +5,7 @@
 #include "ggml.h"
 
 #ifdef GGML_CUDA_USE_CUB
-#   include <cub/device/device_scan.cuh>
+#   include <cub/block/block_scan.cuh>
 #endif // GGML_CUDA_USE_CUB
 
 template<typename T, int BLOCK_SIZE>
@@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel(
         const int64_t  s01, const int64_t  s02, const int64_t  s03,
         const int64_t   s1,  const int64_t   s2,  const int64_t   s3) {
 #ifdef GGML_CUDA_USE_CUB
-    using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
+    using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
 
-    __shared__ typename BlockScan::TempStorage temp_storage;
-    __shared__ T block_carry;      // carry from previous tile
+    __shared__ typename BlockScanT::TempStorage temp_storage;
+    __shared__ T block_carry;
 
     const int tid = threadIdx.x;
+    constexpr int UNROLL_FACTOR = 4;
+    constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
 
     const int64_t i1 = blockIdx.x;
     const int64_t i2 = blockIdx.y;
@@ -39,29 +41,38 @@ static __global__ void cumsum_cub_kernel(
     }
     __syncthreads();
 
-    for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
-        int64_t idx = start + tid;
-        T x = (idx < ne00) ? src_row[idx] : T(0);
+    for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
+        T items[UNROLL_FACTOR];
+        T thread_sum = T(0);
 
-        T inclusive;
-        T block_total;
-        BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
+#pragma unroll
+        for (int i = 0; i < UNROLL_FACTOR; i++) {
+            int64_t idx = start + tid * UNROLL_FACTOR + i;
+            T val = (idx < ne00) ? src_row[idx] : T(0);
+            thread_sum += val;
+            items[i] = thread_sum;
+        }
 
+        // Block-wide scan on thread sums
+        T thread_prefix;
+        T block_total;
+        BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
         __syncthreads();
 
-        T final_val = inclusive + block_carry;
-
-        // store result
-        if (idx < ne00) {
-            dst_row[idx] = final_val;
+        // Add offset to each item and store
+        T thread_offset = thread_prefix - thread_sum + block_carry;
+        #pragma unroll
+        for (int i = 0; i < UNROLL_FACTOR; i++) {
+            int64_t idx = start + tid * UNROLL_FACTOR + i;
+            if (idx < ne00) {
+                dst_row[idx] = items[i] + thread_offset;
+            }
         }
 
-        __syncthreads();
-
+        // Update carry for next tile
         if (tid == 0) {
             block_carry += block_total;
         }
-
         __syncthreads();
     }
 #else
@@ -200,7 +211,7 @@ static void cumsum_cuda(
     const int warps_per_block = block_size / warp_size;
     const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
 
-    if (use_cub) {
+    if (use_cub && ne00 >= 1024) {
         cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
             src, dst,
             ne00, ne01, ne02, ne03,