+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
#include "ssm-scan.cuh"
-template <size_t splitD, size_t N>
-__global__ void __launch_bounds__(splitD, 2)
- ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
- const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+// We would like to keep pragma unroll for cases where L_template is not 0,
+// so we suppress the clang transformation warning.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <size_t splitD, size_t N, size_t L_template>
+__global__ void __launch_bounds__(splitD, 1)
+ ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+ const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
- const int64_t s_off, const int64_t d_inner, const int64_t L) {
-
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- const int bidx = blockIdx.x; // split along B (sequences)
- const int bidy = blockIdx.y; // split along D (d_inner)
- const int tid = threadIdx.x;
- const int wid = tid / 32;
- const int wtid = tid % 32;
-
- extern __shared__ float smem[];
- const int stride_sA = N + 1;
- const int stride_ss0 = N + 1;
- float * smem_A = smem;
- float * smem_s0 = smem_A + splitD * stride_sA;
-
- const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
- const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
- const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
- const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
- const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
- const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
- float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
- float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
-
- const int stride_s0 = src0_nb2 / sizeof(float);
- const int stride_x = src1_nb2 / sizeof(float);
+ const int64_t s_off, const int64_t d_inner, const int64_t L_param)
+{
+ const size_t L = L_template == 0 ? L_param : L_template;
+ const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+ const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
+ const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+ const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+ const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
+ const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
+ float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
+ float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+
+ const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
- const int stride_A = src3_nb1 / sizeof(float);
- const int stride_B = src4_nb2 / sizeof(float);
- const int stride_C = src5_nb2 / sizeof(float);
- const int stride_s = stride_s0;
- const int stride_y = d_inner;
+ const int stride_B = src4_nb2 / sizeof(float);
+ const int stride_C = src5_nb2 / sizeof(float);
+ const int stride_y = d_inner;
- // can N not be 16? for example 32?
- if (N == 16) {
-#pragma unroll
- for (size_t i = 0; i < splitD / 4; i += 2) {
- float value = A_block[(wid * warp_size + i) * stride_A + wtid];
- // todo: bank conflict
- // I am always confused with how to use the swizzling method to solve
- // bank conflit. Hoping somebody can tell me.
- smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
- }
+ float regA[N];
+ float regs0[N];
+
+ __shared__ float smemB[N];
+ __shared__ float smemC[N];
+
+#ifdef USE_CUB
+ using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+ using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+
+ union CubTempStorage {
+ typename BlockLoad::TempStorage load_temp;
+ typename BlockStore::TempStorage store_temp;
+ };
+ __shared__ CubTempStorage cub_temp_storage;
+
+ BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
+ BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
+#else
+ const int stride_s0 = src0_nb2 / sizeof(float);
+ const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
- for (size_t i = 0; i < splitD / 4; i += 2) {
- float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
- smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
- }
+ for (size_t n = 0; n < N; ++n)
+ {
+ regA[n] = A_block[threadIdx.x * stride_A + n];
+ regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
}
+#endif
- __syncthreads();
+#pragma unroll
+ for (size_t i = 0; i < L; i++)
+ {
+ if (threadIdx.x < N)
+ {
+ smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+ smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+ }
+ __syncthreads();
- for (int64_t i = 0; i < L; i++) {
- float dt_soft_plus = dt_block[i * stride_dt + tid];
- if (dt_soft_plus <= 20.0f) {
- dt_soft_plus = log1pf(exp(dt_soft_plus));
+ float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+ if (dt_soft_plus <= 20.0f)
+ {
+ dt_soft_plus = log1pf(expf(dt_soft_plus));
}
- float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+ float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
+
float sumf = 0.0f;
#pragma unroll
- for (size_t j = 0; j < N; j++) {
- float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
- (B_block[i * stride_B + j] * x_dt);
- sumf += state * C_block[i * stride_C + j];
- if (i == L - 1) {
- s_block[tid * stride_s + j] = state;
- } else {
- smem_s0[tid * stride_ss0 + j] = state;
- }
+ for (size_t n = 0; n < N; n++)
+ {
+ float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
+ sumf += state * smemC[n];
+ regs0[n] = state;
}
- __syncthreads();
- y_block[i * stride_y + tid] = sumf;
+ y_block[i * stride_y + threadIdx.x] = sumf;
}
+
+#ifdef USE_CUB
+ BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
+#else
+ const int stride_s = stride_s0;
+#pragma unroll
+ for (size_t n = 0; n < N; ++n)
+ {
+ s_block[threadIdx.x * stride_s + n] = regs0[n];
+ }
+#endif
}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
// assumes as many threads as d_state
template <int splitH, int d_state>
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
+ const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
- const int threads = 128;
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
GGML_ABORT("doesn't support d_state!=(128 or 256).");
}
} else {
- const int threads = 128;
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
- ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
- src0, src1, src2, src3, src4, src5, src6, dst,
+ switch (n_tok)
+ {
+ case 1:
+ ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 2:
+ ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 3:
+ ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 4:
+ ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 5:
+ ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 6:
+ ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 7:
+ ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 8:
+ ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ default:
+ ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ }
} else {
GGML_ABORT("doesn't support d_state!=16.");
}