]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : optimize cuda cumsum fallback kernel (#18343)
authorAadeshveer Singh <redacted>
Thu, 25 Dec 2025 04:11:13 +0000 (09:41 +0530)
committerGitHub <redacted>
Thu, 25 Dec 2025 04:11:13 +0000 (12:11 +0800)
ggml/src/ggml-cuda/cumsum.cu

index d2f2def8bdce8adc542d16d060cd0c52800705b0..0f72e33bbacc9a2c33742d28bcdd1fd300d2752b 100644 (file)
@@ -69,7 +69,7 @@ static __global__ void cumsum_cub_kernel(
 #endif // GGML_CUDA_USE_CUB
 }
 
-// Fallback kernel implementation (original)
+// Fallback kernel implementation
 template<typename T>
 static __global__ void cumsum_kernel(
         const T * src, T * dst,
@@ -86,10 +86,10 @@ static __global__ void cumsum_kernel(
     const int warps_per_block = blockDim.x / warp_size;
 
     extern __shared__ float smem[];
-    float * s_vals = smem;
-    float * s_warp_sums = smem + blockDim.x;
-    float * s_carry = smem + blockDim.x + warps_per_block;
-    float * s_chunk_total = s_carry + 1;
+    float *                 s_vals        = smem;
+    float *                 s_warp_sums   = smem + blockDim.x;
+    float *                 s_carry       = smem + blockDim.x + warps_per_block;
+    float *                 s_chunk_total = s_carry + 1;
 
     // Initialize carry
     if (tid == 0) {
@@ -107,21 +107,39 @@ static __global__ void cumsum_kernel(
     const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
     T       * dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;
 
-    for (int64_t start = 0; start < ne00; start += blockDim.x) {
-        int64_t idx = start + tid;
-        float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
+    // register blocking: process 4 elements per thread to hide latency
+    // and reduce synchronization overhead
+    constexpr int num_unroll = 4;
+    T             temp[num_unroll];
+
+    for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
+        int64_t idx = i + tid * num_unroll;
+
+        // thread local sequential scan
+        temp[0] = (idx < ne00 ? src_row[idx] : T(0));
+#pragma unroll
+        for (int64_t j = 1; j < num_unroll; j++) {
+            temp[j] = temp[j - 1];
+            if (idx + j < ne00) {
+                temp[j] += src_row[idx + j];
+            } else {
+                temp[j] += 0;
+            }
+        }
+
+        // last emenent is sum of all values assigned to thread
+        float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
 
-        // 1. Warp inclusive scan
+        // Warp inclusive scan
         val = warp_prefix_inclusive_sum<T, warp_size>(val);
         s_vals[tid] = val;
 
-        // Store warp total
         if (lane == warp_size - 1) {
             s_warp_sums[warp] = val;
         }
         __syncthreads();
 
-        // 2. Exclusive scan of warp sums (warp 0 only)
+        // Exclusive scan of warp sums (warp 0 only)
         if (warp == 0) {
             float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
             float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
@@ -134,12 +152,17 @@ static __global__ void cumsum_kernel(
         }
         __syncthreads();
 
+        // write back results
         float carry = *s_carry;
-        float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
-        if (idx < ne00) {
-            dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
+        // calculate sum offset for this thread
+        float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
+
+#pragma unroll
+        for (int32_t j = 0; j < num_unroll; j++) {
+            if (idx + j < ne00) {
+                dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
+            }
         }
-        __syncthreads();
 
         // Update carry for next chunk
         if (tid == 0) {