if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+ } else if (nc == 3) {
+ ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+ dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
- GGML_ABORT("Only support kernel size = 4 now.");
+ GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
} else {
if (nc == 4) {
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+ } else if (nc == 3) {
+ const int64_t split_n_t = 32;
+ dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+ ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
+ src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
- GGML_ABORT("Only support kernel size = 4 right now.");
+ GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
}
}