]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : optimize cuda ssm_scan using warp-level reduction (llama/18505)
authorAadeshveer Singh <redacted>
Tue, 6 Jan 2026 18:24:34 +0000 (23:54 +0530)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
* ggml : optimize cuda ssm_scan using warp-level reduction

* ggml : apply code review suggestions (style, const, constexpr)

* ggml : add TODO regarding stride consistency

ggml/src/ggml-cuda/ssm-scan.cu

index 6b424381df5a7311d78b450195e5cb475f787299..c1d4e2bc8dfdc65ec3b8eb129b81e8f5318e2586 100644 (file)
@@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1)
 #endif // __clang__
 
 // assumes as many threads as d_state
-template <int splitH, int d_state>
+template <int c_factor, int d_state>
 __global__ void __launch_bounds__(d_state, 1)
     ssm_scan_f32_group(
         const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
@@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1)
         const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
         const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
 
-    const int head_idx = (blockIdx.x * splitH) / d_head;
-    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
-    const int seq_idx = blockIdx.y;
+    const int warp     = threadIdx.x / WARP_SIZE;
+    const int lane     = threadIdx.x % WARP_SIZE;
+    const int warp_idx = blockIdx.x  * c_factor + warp;
+
+    const int head_idx =  warp_idx / d_head;
+    const int head_off = (warp_idx % d_head) * sizeof(float);
+    const int seq_idx  = blockIdx.y;
 
     const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
 
-    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
-    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
-    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
-    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
-    float *       y_block  = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
-    float *       s_block  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
+    const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_warp  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
+    const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_warp  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_warp  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_warp  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float *       y_warp  = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
+    float *       s_warp  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
 
     // strides across n_seq_tokens
     const int stride_x  = src1_nb2 / sizeof(float);
@@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1)
     const int stride_C  = src5_nb2 / sizeof(float);
     const int stride_y  = n_head * d_head;
 
-    float state[splitH];
-    // for the parallel accumulation
-    __shared__ float stateC[splitH * d_state];
+    float state[c_factor];
+    float state_sum = 0.0f;
 
 #pragma unroll
-    for (int j = 0; j < splitH; j++) {
-        state[j] = s0_block[j * d_state + threadIdx.x];
+    for (int j = 0; j < c_factor; j++) {
+        state[j] = s0_warp[WARP_SIZE * j + lane];
     }
 
     for (int64_t i = 0; i < n_tok; i++) {
-        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
-        // TODO: only calculate B and C once per head group
-        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
-        float dt_soft_plus = dt_block[i * stride_dt];
-        if (dt_soft_plus <= 20.0f) {
-            dt_soft_plus = log1pf(expf(dt_soft_plus));
-        }
-        const float dA = expf(dt_soft_plus * A_block[0]);
-        const float B = B_block[i * stride_B + threadIdx.x];
-        const float C = C_block[i * stride_C + threadIdx.x];
+        // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
+        // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
+        const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);
 
-        // across d_head
+        state_sum = 0.0f;
+        const float dA   = expf(dt_soft_plus * A_warp[0]);
+        const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
 #pragma unroll
-        for (int j = 0; j < splitH; j++) {
-            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
-
-            state[j] = (state[j] * dA) + (B * x_dt);
-
-            stateC[j * d_state + threadIdx.x] = state[j] * C;
+        for (int j = 0; j < c_factor; j++) {
+            const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
+            const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
+            state[j] = (state[j] * dA) + (B_val * x_dt);
+            state_sum += state[j] * C_val;
         }
 
-        __syncthreads();
-
-        // parallel accumulation for stateC
-        // TODO: simplify
-        {
-            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
-            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
-
-            // reduce until w matches the warp size
-            // TODO: does this work even when the physical warp size is 64?
-#pragma unroll
-            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
-                // (assuming there are d_state threads)
-#pragma unroll
-                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
-                    // TODO: check for bank conflicts
-                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
-                    stateC[k] += stateC[k + (w >> 1)];
-
-                }
-                __syncthreads();
-            }
-
-            static_assert(splitH >= d_state / WARP_SIZE);
+        // parallel accumulation for output
+        state_sum = warp_reduce_sum(state_sum);
 
-#pragma unroll
-            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
-                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
-                y = warp_reduce_sum(y);
-
-                // store the above accumulations
-                if (threadIdx.x % WARP_SIZE == 0) {
-                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
-                    y_block[i * stride_y + k] = y;
-                }
-            }
+        if (lane == 0) {
+            y_warp[i * stride_y] = state_sum;
         }
     }
 
     // write back the state
 #pragma unroll
-    for (int j = 0; j < splitH; j++) {
-        s_block[j * d_state + threadIdx.x] = state[j];
+    for (int j = 0; j < c_factor; j++) {
+        s_warp[WARP_SIZE * j + lane] = state[j];
     }
 }
 
@@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
-    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
-            GGML_ASSERT(d_state % threads == 0);
-            // NOTE: can be any power of two between 4 and 64
-            const int splitH = 16;
-            GGML_ASSERT(head_dim % splitH == 0);
-            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
-            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+            constexpr int threads   = 128;
+            constexpr int num_warps = threads/WARP_SIZE;
+
+            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+            ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
                     src0, src1, src2, src3, src4, src5, src6, dst,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                     src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
         } else if (d_state == 256) { // Falcon-H1
-            const int threads = 256;
-            // NOTE: can be any power of two between 8 and 64
-            const int splitH = 16;
-            GGML_ASSERT(head_dim % splitH == 0);
-            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
-            ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
+            constexpr int threads   = 256;
+            constexpr int num_warps = threads/WARP_SIZE;
+
+            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+            ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
                     src0, src1, src2, src3, src4, src5, src6, dst,
                     src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                     src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
@@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
         }
     } else {
         // Mamba-1
+        constexpr int threads = 128;
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
         GGML_ASSERT(n_group == 1);