From: Aman Gupta Date: Thu, 25 Dec 2025 15:55:38 +0000 (+0800) Subject: cuda: optimize cumsum cub path (llama/18362) X-Git-Tag: upstream/1.8.3~88 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=8e02f0919d6eefdd2d00894025601c794b7fafdc;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp cuda: optimize cumsum cub path (llama/18362) * cuda: optimize cumsum cub path * remove heavy perf test --- diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 0f72e33b..e82171f9 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -5,7 +5,7 @@ #include "ggml.h" #ifdef GGML_CUDA_USE_CUB -# include +# include #endif // GGML_CUDA_USE_CUB template @@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel( const int64_t s01, const int64_t s02, const int64_t s03, const int64_t s1, const int64_t s2, const int64_t s3) { #ifdef GGML_CUDA_USE_CUB - using BlockScan = cub::BlockScan; + using BlockScanT = cub::BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ T block_carry; // carry from previous tile + __shared__ typename BlockScanT::TempStorage temp_storage; + __shared__ T block_carry; const int tid = threadIdx.x; + constexpr int UNROLL_FACTOR = 4; + constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR; const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.y; @@ -39,29 +41,38 @@ static __global__ void cumsum_cub_kernel( } __syncthreads(); - for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) { - int64_t idx = start + tid; - T x = (idx < ne00) ? src_row[idx] : T(0); + for (int64_t start = 0; start < ne00; start += TILE_SIZE) { + T items[UNROLL_FACTOR]; + T thread_sum = T(0); - T inclusive; - T block_total; - BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total); +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + T val = (idx < ne00) ? src_row[idx] : T(0); + thread_sum += val; + items[i] = thread_sum; + } + // Block-wide scan on thread sums + T thread_prefix; + T block_total; + BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total); __syncthreads(); - T final_val = inclusive + block_carry; - - // store result - if (idx < ne00) { - dst_row[idx] = final_val; + // Add offset to each item and store + T thread_offset = thread_prefix - thread_sum + block_carry; + #pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + if (idx < ne00) { + dst_row[idx] = items[i] + thread_offset; + } } - __syncthreads(); - + // Update carry for next tile if (tid == 0) { block_carry += block_total; } - __syncthreads(); } #else @@ -200,7 +211,7 @@ static void cumsum_cuda( const int warps_per_block = block_size / warp_size; const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); - if (use_cub) { + if (use_cub && ne00 >= 1024) { cumsum_cub_kernel<<>>( src, dst, ne00, ne01, ne02, ne03,