const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
- const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
+ const int threads = 128;
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+ } else if (d_state == 256) { // Falcon-H1
+ const int threads = 256;
+ // NOTE: can be any power of two between 8 and 64
+ const int splitH = 16;
+ GGML_ASSERT(head_dim % splitH == 0);
+ const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+ ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
} else {
- GGML_ABORT("doesn't support d_state!=128.");
+ GGML_ABORT("doesn't support d_state!=(128 or 256).");
}
} else {
+ const int threads = 128;
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
+ test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));