#endif // GGML_CUDA_USE_CUB
}
-// Fallback kernel implementation (original)
+// Fallback kernel implementation
template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
const int warps_per_block = blockDim.x / warp_size;
extern __shared__ float smem[];
- float * s_vals = smem;
- float * s_warp_sums = smem + blockDim.x;
- float * s_carry = smem + blockDim.x + warps_per_block;
- float * s_chunk_total = s_carry + 1;
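+    // shared-memory layout: blockDim.x per-thread scan values, then
+    // warps_per_block warp totals, then two scalars (carry, chunk total)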
+ float * s_vals = smem;
+ float * s_warp_sums = smem + blockDim.x;
+ float * s_carry = smem + blockDim.x + warps_per_block;
+ float * s_chunk_total = s_carry + 1;
// Initialize carry
if (tid == 0) {
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
- for (int64_t start = 0; start < ne00; start += blockDim.x) {
- int64_t idx = start + tid;
- float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
+    // register blocking: each thread processes num_unroll consecutive
+    // elements to hide latency and reduce synchronization overhead
+    constexpr int num_unroll = 4;
+ T temp[num_unroll];
+
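+    // block-stride loop: each iteration covers num_unroll * blockDim.x
+    // consecutive elements of the row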
+ for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
+ int64_t idx = i + tid * num_unroll;
+
+ // thread local sequential scan
+ temp[0] = (idx < ne00 ? src_row[idx] : T(0));
+#pragma unroll
+        for (int j = 1; j < num_unroll; j++) {
+            temp[j] = temp[j - 1];
+            if (idx + j < ne00) {
+                temp[j] += src_row[idx + j];
+            }
+        }
+
+        // the last element is the sum of all values assigned to this thread
+ float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
- // 1. Warp inclusive scan
+ // Warp inclusive scan
val = warp_prefix_inclusive_sum<T, warp_size>(val);
s_vals[tid] = val;
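+    // s_vals[tid] now holds the inclusive prefix of per-thread totals
+    // within this warp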
- // Store warp total
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();
- // 2. Exclusive scan of warp sums (warp 0 only)
+ // Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
}
__syncthreads();
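+    // after this barrier, warp 0's writes (warp offsets, chunk total) are
+    // visible to every thread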
+ // write back results
float carry = *s_carry;
- float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
- if (idx < ne00) {
- dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
+    // exclusive offset for this thread: block-wide inclusive prefix plus the
+    // running carry, minus this thread's own total already counted in s_vals[tid]
+    float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - ggml_cuda_cast<float, T>(temp[num_unroll - 1]);
+
+#pragma unroll
+    for (int j = 0; j < num_unroll; j++) {
+ if (idx + j < ne00) {
+ dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
+ }
}
+    // this barrier must stay: *s_carry is read above by every thread and is
+    // overwritten below by thread 0
    __syncthreads();
// Update carry for next chunk
if (tid == 0) {