GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
- // allows optimizing the modulo since n_group should be a power of 2
- GGML_ASSERT((ng & -ng) == ng);
+ GGML_ASSERT(nh % ng == 0);
// heads per thread
const int dh = (nh + nth - 1)/nth;
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
+ const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
// TODO: maybe unroll more?
for (int j = 0; j < 1; j++) {
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
- GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
- GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
t0 = GGML_F32_VEC_MUL(t0, adA);
t1 = GGML_F32_VEC_MUL(t1, axdt);
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
- ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
- az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const int g = h / (nh / ng); // repeat_interleave
// dim
for (int i1 = 0; i1 < nr; ++i1) {
// TODO: what happens when (d_state % svcntw()) != 0?
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
- const int ig = i0 + (h & (ng - 1))*nc;
+ const int ig = i0 + g*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
+ const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
- device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
- device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int64_t i = i0 + i1*nc;
+ const int64_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = s_buff[i];
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
- device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
- device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+ device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
+ device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
for (int64_t i2 = 0; i2 < n_t; ++i2) {