const int threads = 128;
GGML_ASSERT(nr % threads == 0);
- if (n_t <= 32) {
- const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
- if (nc == 4) {
- ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
- dst, dst_nb0, dst_nb1, dst_nb2, n_t);
- } else if (nc == 3) {
- ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
- dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+ auto launch_kernel = [&](auto NC) {
+ constexpr int kNC = decltype(NC)::value;
+ if (n_t <= 32) {
+ const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+ ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+ dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
- GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
- }
- } else {
- if (nc == 4) {
- const int64_t split_n_t = 32;
- dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
- ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
- src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
- } else if (nc == 3) {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
- ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
+ ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
- } else {
- GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
+ };
+
+ switch (nc) {
+ case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
+ case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
+ case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
+ default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
}
}