]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda: refactored ssm_scan and use CUB (#13291)
authorDavid Zhao <redacted>
Sat, 9 Aug 2025 18:29:43 +0000 (13:29 -0500)
committerGitHub <redacted>
Sat, 9 Aug 2025 18:29:43 +0000 (20:29 +0200)
* cuda: refactored ssm_scan to use CUB

* fixed compilation error when when not using CUB

* assign L to constant and use size_t instead of int

* deduplicated functions

* change min blocks per mp to 1

* Use cub load and store warp transpose

* suppress clang warning

ggml/src/ggml-cuda/ssm-scan.cu

index c9184398b422c7f3c1e1938692db24d40dbacb1f..dc9a7d58d057c14f6090081e3a872a45a11328c8 100644 (file)
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
 #include "ssm-scan.cuh"
 
-template <size_t splitD, size_t N>
-__global__ void __launch_bounds__(splitD, 2)
-    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
-                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+// We would like to keep pragma unroll for cases where L_template is not 0,
+// so we suppress the clang transformation warning.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <size_t splitD, size_t N, size_t L_template>
+__global__ void __launch_bounds__(splitD, 1)
+    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                  const int32_t * __restrict__ src6, float * __restrict__ dst,
                  const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                  const int src2_nb1, const int src2_nb2, const int src3_nb1,
                  const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
-                 const int64_t s_off, const int64_t d_inner, const int64_t L) {
-
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x;  // split along B (sequences)
-    const int bidy = blockIdx.y;  // split along D (d_inner)
-    const int tid  = threadIdx.x;
-    const int wid  = tid / 32;
-    const int wtid = tid % 32;
-
-    extern __shared__ float smem[];
-    const int               stride_sA  = N + 1;
-    const int               stride_ss0 = N + 1;
-    float *                 smem_A     = smem;
-    float *                 smem_s0    = smem_A + splitD * stride_sA;
-
-    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
-    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
-    const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
-    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
-    float *       y_block  = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
-    float *       s_block  = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
-
-    const int stride_s0 = src0_nb2 / sizeof(float);
-    const int stride_x  = src1_nb2 / sizeof(float);
+                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
+{
+    const size_t L = L_template == 0 ? L_param : L_template;
+    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
+    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
+    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
+    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
+    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+
+    const int stride_x = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
-    const int stride_A  = src3_nb1 / sizeof(float);
-    const int stride_B  = src4_nb2 / sizeof(float);
-    const int stride_C  = src5_nb2 / sizeof(float);
-    const int stride_s  = stride_s0;
-    const int stride_y  = d_inner;
+    const int stride_B = src4_nb2 / sizeof(float);
+    const int stride_C = src5_nb2 / sizeof(float);
+    const int stride_y = d_inner;
 
-    // can N not be 16? for example 32?
-    if (N == 16) {
-#pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = A_block[(wid * warp_size + i) * stride_A + wtid];
-            // todo: bank conflict
-            // I am always confused with how to use the swizzling method to solve
-            // bank conflit. Hoping somebody can tell me.
-            smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
+    float regA[N];
+    float regs0[N];
+
+    __shared__ float smemB[N];
+    __shared__ float smemC[N];
+
+#ifdef USE_CUB
+    using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+
+    union CubTempStorage {
+        typename BlockLoad::TempStorage load_temp;
+        typename BlockStore::TempStorage store_temp;
+    };
+    __shared__ CubTempStorage cub_temp_storage;
+
+    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
+    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
+#else
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_A = src3_nb1 / sizeof(float);
 #pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
-            smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
+    for (size_t n = 0; n < N; ++n)
+    {
+        regA[n] = A_block[threadIdx.x * stride_A + n];
+        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
     }
+#endif
 
-    __syncthreads();
+#pragma unroll
+    for (size_t i = 0; i < L; i++)
+    {
+        if (threadIdx.x < N)
+        {
+            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+        }
+        __syncthreads();
 
-    for (int64_t i = 0; i < L; i++) {
-        float dt_soft_plus = dt_block[i * stride_dt + tid];
-        if (dt_soft_plus <= 20.0f) {
-            dt_soft_plus = log1pf(exp(dt_soft_plus));
+        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+        if (dt_soft_plus <= 20.0f)
+        {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
         }
-        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
+
         float sumf = 0.0f;
 #pragma unroll
-        for (size_t j = 0; j < N; j++) {
-            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
-                          (B_block[i * stride_B + j] * x_dt);
-            sumf += state * C_block[i * stride_C + j];
-            if (i == L - 1) {
-                s_block[tid * stride_s + j] = state;
-            } else {
-                smem_s0[tid * stride_ss0 + j] = state;
-            }
+        for (size_t n = 0; n < N; n++)
+        {
+            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
+            sumf += state * smemC[n];
+            regs0[n] = state;
         }
-        __syncthreads();
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + threadIdx.x] = sumf;
     }
+
+#ifdef USE_CUB
+    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
+#else
+    const int stride_s = stride_s0;
+#pragma unroll
+    for (size_t n = 0; n < N; ++n)
+    {
+        s_block[threadIdx.x * stride_s + n] = regs0[n];
+    }
+#endif
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
 
 // assumes as many threads as d_state
 template <int splitH, int d_state>
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
+    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
-            const int threads = 128;
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
             const int splitH = 16;
@@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
             GGML_ABORT("doesn't support d_state!=(128 or 256).");
         }
     } else {
-        const int threads = 128;
         // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
@@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
         const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
         const int  smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
         if (d_state == 16) {
-            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-                src0, src1, src2, src3, src4, src5, src6, dst,
+            switch (n_tok)
+            {
+            case 1:
+                ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 2:
+                ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 3:
+                ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 4:
+                ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 5:
+                ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 6:
+                ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 7:
+                ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 8:
+                ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            default:
+                ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            }
         } else {
             GGML_ABORT("doesn't support d_state!=16.");
         }