#endif // __clang__
// assumes as many threads as d_state
-template <int splitH, int d_state>
+template <int c_factor, int d_state>
__global__ void __launch_bounds__(d_state, 1)
ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
- const int head_idx = (blockIdx.x * splitH) / d_head;
- const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
- const int seq_idx = blockIdx.y;
+ const int warp = threadIdx.x / WARP_SIZE;
+ const int lane = threadIdx.x % WARP_SIZE;
+ const int warp_idx = blockIdx.x * c_factor + warp;
+
+ const int head_idx = warp_idx / d_head;
+ const int head_off = (warp_idx % d_head) * sizeof(float);
+ const int seq_idx = blockIdx.y;
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
- const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
- const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
- const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
- const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
- const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
- const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
- float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
- float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+ // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
+ const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+ const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
+ const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+ const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+ const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+ const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+ float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
+ float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// strides across n_seq_tokens
const int stride_x = src1_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head;
- float state[splitH];
- // for the parallel accumulation
- __shared__ float stateC[splitH * d_state];
+ float state[c_factor];
+ float state_sum = 0.0f;
#pragma unroll
- for (int j = 0; j < splitH; j++) {
- state[j] = s0_block[j * d_state + threadIdx.x];
+ for (int j = 0; j < c_factor; j++) {
+ state[j] = s0_warp[WARP_SIZE * j + lane];
}
for (int64_t i = 0; i < n_tok; i++) {
- // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
- // TODO: only calculate B and C once per head group
- // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
- float dt_soft_plus = dt_block[i * stride_dt];
- if (dt_soft_plus <= 20.0f) {
- dt_soft_plus = log1pf(expf(dt_soft_plus));
- }
- const float dA = expf(dt_soft_plus * A_block[0]);
- const float B = B_block[i * stride_B + threadIdx.x];
- const float C = C_block[i * stride_C + threadIdx.x];
+ // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
+ // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
+ const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);
- // across d_head
+ state_sum = 0.0f;
+ const float dA = expf(dt_soft_plus * A_warp[0]);
+ const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll
- for (int j = 0; j < splitH; j++) {
- const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
-
- state[j] = (state[j] * dA) + (B * x_dt);
-
- stateC[j * d_state + threadIdx.x] = state[j] * C;
+ for (int j = 0; j < c_factor; j++) {
+ const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
+ const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
+ state[j] = (state[j] * dA) + (B_val * x_dt);
+ state_sum += state[j] * C_val;
}
- __syncthreads();
-
- // parallel accumulation for stateC
- // TODO: simplify
- {
- static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
- static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
-
- // reduce until w matches the warp size
- // TODO: does this work even when the physical warp size is 64?
-#pragma unroll
- for (int w = d_state; w > WARP_SIZE; w >>= 1) {
- // (assuming there are d_state threads)
-#pragma unroll
- for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
- // TODO: check for bank conflicts
- const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
- stateC[k] += stateC[k + (w >> 1)];
-
- }
- __syncthreads();
- }
-
- static_assert(splitH >= d_state / WARP_SIZE);
+ // parallel accumulation for output
+ state_sum = warp_reduce_sum(state_sum);
-#pragma unroll
- for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
- float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
- y = warp_reduce_sum(y);
-
- // store the above accumulations
- if (threadIdx.x % WARP_SIZE == 0) {
- const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
- y_block[i * stride_y + k] = y;
- }
- }
+ if (lane == 0) {
+ y_warp[i * stride_y] = state_sum;
}
}
// write back the state
#pragma unroll
- for (int j = 0; j < splitH; j++) {
- s_block[j * d_state + threadIdx.x] = state[j];
+ for (int j = 0; j < c_factor; j++) {
+ s_warp[WARP_SIZE * j + lane] = state[j];
}
}
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
- const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
- GGML_ASSERT(d_state % threads == 0);
- // NOTE: can be any power of two between 4 and 64
- const int splitH = 16;
- GGML_ASSERT(head_dim % splitH == 0);
- const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
- ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+ constexpr int threads = 128;
+ constexpr int num_warps = threads/WARP_SIZE;
+
+ const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+ ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
} else if (d_state == 256) { // Falcon-H1
- const int threads = 256;
- // NOTE: can be any power of two between 8 and 64
- const int splitH = 16;
- GGML_ASSERT(head_dim % splitH == 0);
- const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
- ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
+ constexpr int threads = 256;
+ constexpr int num_warps = threads/WARP_SIZE;
+
+ const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+ ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
}
} else {
// Mamba-1
+ constexpr int threads = 128;
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);