#include "ggml.h"
#ifdef GGML_CUDA_USE_CUB
-# include <cub/device/device_scan.cuh>
+# include <cub/block/block_scan.cuh>
#endif // GGML_CUDA_USE_CUB
template<typename T, int BLOCK_SIZE>
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
- using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
+ using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
- __shared__ typename BlockScan::TempStorage temp_storage;
- __shared__ T block_carry; // carry from previous tile
+ __shared__ typename BlockScanT::TempStorage temp_storage;
+ __shared__ T block_carry; // carry from previous tiles
const int tid = threadIdx.x;
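+ // Each thread processes UNROLL_FACTOR consecutive elements, so one tile covers BLOCK_SIZE * UNROLL_FACTOR values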
+ constexpr int UNROLL_FACTOR = 4;
+ constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
}
__syncthreads();
- for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
- int64_t idx = start + tid;
- T x = (idx < ne00) ? src_row[idx] : T(0);
+ for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
+ T items[UNROLL_FACTOR];
+ T thread_sum = T(0);
- T inclusive;
- T block_total;
- BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
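+ // Load this thread's UNROLL_FACTOR elements and build their local inclusive prefix sums in registers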
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
+ T val = (idx < ne00) ? src_row[idx] : T(0);
+ thread_sum += val;
+ items[i] = thread_sum;
+ }
+ // Block-wide inclusive scan over the per-thread sums; block_total is the sum of the whole tile
+ T thread_prefix;
+ T block_total;
+ BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
__syncthreads();
- T final_val = inclusive + block_carry;
-
- // store result
- if (idx < ne00) {
- dst_row[idx] = final_val;
+ // Convert the inclusive thread prefix to an exclusive offset and add the carry from previous tiles,
+ // then add this offset to each local item and store
+ T thread_offset = thread_prefix - thread_sum + block_carry;
+ #pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
+ if (idx < ne00) {
+ dst_row[idx] = items[i] + thread_offset;
+ }
}
- __syncthreads();
-
+ // Make sure every thread has read block_carry before thread 0 updates it
+ __syncthreads();
+ // Update carry for next tile
if (tid == 0) {
block_carry += block_total;
}
-
__syncthreads();
}
#else
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
- if (use_cub) {
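+ // Heuristic: only take the CUB block-scan path for rows with at least 1024 elements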
+ if (use_cub && ne00 >= 1024) {
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,