struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
- struct ggml_tensor * C);
+ struct ggml_tensor * C,
+ struct ggml_tensor * ids);
// partition into non-overlapping windows with padding if needed
// example:
static void ggml_compute_forward_ssm_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0]; // s
- const ggml_tensor * src1 = dst->src[1]; // x
- const ggml_tensor * src2 = dst->src[2]; // dt
- const ggml_tensor * src3 = dst->src[3]; // A
- const ggml_tensor * src4 = dst->src[4]; // B
- const ggml_tensor * src5 = dst->src[5]; // C
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
const int ith = params->ith;
const int nth = params->nth;
- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // dim
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1];
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
+
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+ // allows optimizing the modulo since n_group should be a power of 2
+ GGML_ASSERT((ng & -ng) == ng);
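+    // e.g. ng == 4: h % ng == (h & (ng - 1)) == (h & 3), which is how the B/C group index is computed below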
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // heads per thread
+ const int dh = (nh + nth - 1)/nth;
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
+ // head range for this thread
+ const int ih0 = dh*ith;
+ const int ih1 = MIN(ih0 + dh, nh);
+
+ const int32_t * ids = (const int32_t *) src6->data;
+
+ for (int i3 = 0; i3 < ns; ++i3) {
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+ for (int i2 = 0; i2 < nt; ++i2) {
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+ if (src3->ne[0] == 1) {
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
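+                    // softplus: for x > 20.0f, log1pf(expf(x)) == x to within float precision (and expf would overflow for much larger x), hence the cutoff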
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dA = expf(dt_soft_plus * A[h]);
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ float sumf = 0.0f;
+#if defined(GGML_SIMD)
+ #if defined(__ARM_FEATURE_SVE)
+ const int ggml_f32_epr = svcntw();
+ const int ggml_f32_step = 1 * ggml_f32_epr;
+
+ const int np = (nc & ~(ggml_f32_step - 1));
+
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ // TODO: maybe unroll more?
+ for (int j = 0; j < 1; j++) {
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+ t0 = GGML_F32_VEC_MUL(t0, adA);
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+ t0 = GGML_F32_VEC_ADD(t0, t1);
+
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+ }
+ }
+
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #else
+ const int np = (nc & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC az[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
- #ifdef __ARM_FEATURE_SVE
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
- for (int64_t k = 0; k < nc; k += svcntw()) {
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
- t1 = exp_ps_sve(svptrue_b32(), t1);
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+ }
+ }
+
+                    // reduce the GGML_F32_ARR partial sums into sumf
+ GGML_F32_VEC_REDUCE(sumf, sum);
+ #endif
+#else
+ const int np = 0;
+#endif
+ // d_state
+ for (int i0 = np; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
}
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
}
- }
- }
- #else
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
+ } else {
+ // Mamba-1 has an element-wise decay factor for the states
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+ // d_state
+ // TODO: what happens when (d_state % svcntw()) != 0?
+ for (int64_t k = 0; k < nc; k += svcntw()) {
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+ t1 = exp_ps_sve(svptrue_b32(), t1);
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+ }
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+ float sumf = 0.0f;
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+ // and also because expf is used within the loop.
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
+#endif
}
- y[i1] = sumf;
}
}
+ // use the output as the source when it's not the first token-wise iteration
+ s0 = s;
}
- #endif
+ }
}
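For reference, both SIMD paths above compute the same per-head recurrence; a plain scalar sketch for a single head, group and token follows (the function name and flat array layout are illustrative only, not part of the patch):

#include <math.h>

// s, s0: {d_state, dim} laid out as s[i0 + i1*d_state]; x, y: {dim}; B, C: {d_state}
// A: {d_state} for Mamba-1, or a single scalar entry for Mamba-2 (scalar_A != 0)
static void ssm_scan_head_ref(float * s, float * y,
                              const float * s0, const float * x,
                              const float * A, const float * B, const float * C,
                              float dt, int d_state, int dim, int scalar_A) {
    const float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // softplus
    for (int i1 = 0; i1 < dim; ++i1) {
        const float x_dt = x[i1] * dt_sp;
        float sumf = 0.0f;
        for (int i0 = 0; i0 < d_state; ++i0) {
            const float dA    = expf(dt_sp * (scalar_A ? A[0] : A[i0]));
            const float state = s0[i0 + i1*d_state] * dA + B[i0] * x_dt; // state = prev_state*dA + dB*x
            sumf += state * C[i0];                                       // y = rowwise_dotprod(state, C)
            s[i0 + i1*d_state] = state;
        }
        y[i1] = sumf;
    }
}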
void ggml_compute_forward_ssm_scan(
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
-#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
+#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
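+// note: svmad_f32_m(pg, b, c, a) computes b*c + a, so GGML_F32xt_FMA(a, b, c) == a + b*c, matching the argument order of the other GGML_F32_VEC_FMA implementations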
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
for (int i = 0; i < np; i += ggml_f32_step) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
- sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
+ sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
- sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
+ sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
- sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
+ sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
- sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
+ sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
- sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
+ sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
- sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
+ sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
- sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
+ sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
}
// leftovers
    // Since the loop above is unrolled 8 times, the number of leftover elements lies in [0, ggml_f32_step) and is handled by the loop below
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
}
    // the maximum number of leftover elements will be less than ggml_f32_epr. Apply predicated svmad on available elements only
if (np2 < n) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
- ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
+ ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
- ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
+ ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
- ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
+ ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
- ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
+ ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
- ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
+ ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
- ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
+ ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
- ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
+ ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
}
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
}
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LOG:
- case GGML_OP_SSM_SCAN:
- case GGML_OP_SSM_CONV:
return true;
+ case GGML_OP_SSM_SCAN: {
+ if (op->src[3]->ne[0] == 1) {
+ // Mamba2
+ // (kernel only supports d_state == 128 && d_head % 16 == 0)
+ return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+ } else {
+ // Mamba
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+ }
+ }
+ case GGML_OP_SSM_CONV: {
+ // assumes d_inner % threads == 0
+ return op->src[0]->ne[1] % 128 == 0;
+ }
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_DIAG_MASK_INF:
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
- const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
- const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
- const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
- float * __restrict__ dst, const int64_t L) {
- GGML_UNUSED(src1_nb0);
- GGML_UNUSED(src2_nb0);
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+ const int64_t s_off, const int64_t d_inner, const int64_t L) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- const int bidx = blockIdx.x; // split along B
- const int bidy = blockIdx.y; // split along D
+ const int bidx = blockIdx.x; // split along B (sequences)
+ const int bidy = blockIdx.y; // split along D (d_inner)
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
- const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
- const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+ const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
+ const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
- const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
- const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
- float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
- float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+ const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
+ const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
+ float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
+ float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
- const int stride_s0 = src0_nb1 / sizeof(float);
- const int stride_x = src1_nb1 / sizeof(float);
+ const int stride_s0 = src0_nb2 / sizeof(float);
+ const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
- const int stride_B = src4_nb1 / sizeof(float);
- const int stride_C = src5_nb1 / sizeof(float);
+ const int stride_B = src4_nb2 / sizeof(float);
+ const int stride_C = src5_nb2 / sizeof(float);
const int stride_s = stride_s0;
- const int stride_y = stride_x;
+ const int stride_y = d_inner;
    // can N be something other than 16? for example 32?
if (N == 16) {
}
}
+// assumes as many threads as d_state
+template <int splitH, int d_state>
+__global__ void __launch_bounds__(d_state, 1)
+ ssm_scan_f32_group(
+ const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+ const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+ const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
+
+ const int head_idx = (blockIdx.x * splitH) / d_head;
+ const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
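+    // e.g. with splitH == 16 and d_head == 64, blockIdx.x == 5 gives head_idx == 1 and head_off == 16*sizeof(float) (16 floats into that head)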
+ const int seq_idx = blockIdx.y;
+
+ const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+ const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+ const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
+ const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+ const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+ const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+ const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+ float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
+ float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+ // strides across n_seq_tokens
+ const int stride_x = src1_nb2 / sizeof(float);
+ const int stride_dt = src2_nb1 / sizeof(float);
+ const int stride_B = src4_nb2 / sizeof(float);
+ const int stride_C = src5_nb2 / sizeof(float);
+ const int stride_y = n_head * d_head;
+
+ float state[splitH];
+ // for the parallel accumulation
+ __shared__ float stateC[splitH * d_state];
+
+#pragma unroll
+ for (int j = 0; j < splitH; j++) {
+ state[j] = s0_block[j * d_state + threadIdx.x];
+ }
+
+ for (int64_t i = 0; i < n_tok; i++) {
+ // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+ // TODO: only calculate B and C once per head group
+ // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+ float dt_soft_plus = dt_block[i * stride_dt];
+ if (dt_soft_plus <= 20.0f) {
+ dt_soft_plus = log1pf(expf(dt_soft_plus));
+ }
+ const float dA = expf(dt_soft_plus * A_block[0]);
+ const float B = B_block[i * stride_B + threadIdx.x];
+ const float C = C_block[i * stride_C + threadIdx.x];
+
+ // across d_head
+#pragma unroll
+ for (int j = 0; j < splitH; j++) {
+ const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+ state[j] = (state[j] * dA) + (B * x_dt);
+
+ stateC[j * d_state + threadIdx.x] = state[j] * C;
+ }
+
+ __syncthreads();
+
+ // parallel accumulation for stateC
+ // TODO: simplify
+ {
+ static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+ static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+ // reduce until w matches the warp size
+ // TODO: does this work even when the physical warp size is 64?
+#pragma unroll
+ for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+ // (assuming there are d_state threads)
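+            // e.g. with d_state == 128 and WARP_SIZE == 32, the loop runs for w == 128 and w == 64, leaving WARP_SIZE partial sums per splitH row for the warp reduction below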
+#pragma unroll
+ for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+ // TODO: check for bank conflicts
+ const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+ stateC[k] += stateC[k + (w >> 1)];
+
+ }
+ __syncthreads();
+ }
+
+ static_assert(splitH >= d_state / WARP_SIZE);
+
+#pragma unroll
+ for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+ float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+ y = warp_reduce_sum(y);
+
+ // store the above accumulations
+ if (threadIdx.x % WARP_SIZE == 0) {
+ const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+ y_block[i * stride_y + k] = y;
+ }
+ }
+ }
+ }
+
+ // write back the state
+#pragma unroll
+ for (int j = 0; j < splitH; j++) {
+ s_block[j * d_state + threadIdx.x] = state[j];
+ }
+}
+
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
- const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
- const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
- const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
- const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
- float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
+ const float * src4, const float * src5, const int32_t * src6, float * dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+ const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+ const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+ const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
- // todo: consider D cannot be divided,does this situation exist?
- GGML_ASSERT(D % threads == 0);
- const dim3 blocks(B, (D + threads - 1) / threads, 1);
- const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
- if (N == 16) {
- ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
- src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
- src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+ // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+ if (src3_nb1 == sizeof(float)) {
+ // Mamba-2
+ if (d_state == 128) {
+ GGML_ASSERT(d_state % threads == 0);
+ // NOTE: can be any power of two between 4 and 64
+ const int splitH = 16;
+ GGML_ASSERT(head_dim % splitH == 0);
+ const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
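+            // one block per splitH-sized slice of d_inner (= n_head * head_dim), one row of blocks per sequence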
+ ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+ } else {
+ GGML_ABORT("doesn't support d_state!=128.");
+ }
} else {
- GGML_ABORT("doesn't support N!=16.");
+ // Mamba-1
+ GGML_ASSERT(n_head % threads == 0);
+ GGML_ASSERT(head_dim == 1);
+ GGML_ASSERT(n_group == 1);
+ const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
+ const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
+ if (d_state == 16) {
+ ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ } else {
+ GGML_ABORT("doesn't support d_state!=16.");
+ }
}
}
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
-
- // const int64_t d_state = src0->ne[0];
- // const int64_t d_inner = src0->ne[1];
- // const int64_t l = src1->ne[1];
- // const int64_t b = src0->ne[2];
+ const struct ggml_tensor * src6 = dst->src[6]; // ids
const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nr = src0->ne[1]; // head_dim or 1
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1]; // n_group
+ const int64_t n_t = src1->ne[2]; // number of tokens per sequence
+ const int64_t n_s = src1->ne[3]; // number of sequences in the batch
+
+ const int64_t s_off = ggml_nelements(src1) * sizeof(float);
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
const float * src2_d = (const float *) src2->data;
const float * src3_d = (const float *) src3->data;
const float * src4_d = (const float *) src4->data;
const float * src5_d = (const float *) src5->data;
+ const int32_t * src6_d = (const int32_t *) src6->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src6->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
- src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
- src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+ ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+ src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+ src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+ s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
typedef struct {
int64_t d_state;
int64_t d_inner;
+ int64_t n_head;
+ int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
- uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
- uint64_t nb10;
+ uint64_t nb03;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
- uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
- uint64_t nb30;
uint64_t nb31;
- uint64_t nb40;
uint64_t nb41;
uint64_t nb42;
- uint64_t nb50;
+ uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
+ uint64_t nb53;
} ggml_metal_kargs_ssm_scan;
typedef struct {
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+ GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
+ struct ggml_tensor * src6 = node->src[6];
GGML_ASSERT(src3);
GGML_ASSERT(src4);
GGML_ASSERT(src5);
+ GGML_ASSERT(src6);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
+ size_t offs_src6 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+ id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
- const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+ const int64_t ne30 = src3->ne[0];
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
- const uint64_t nb30 = src3->nb[0];
+ const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
const uint64_t nb31 = src3->nb[1];
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
- const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+ const int64_t ne41 = src4->ne[1];
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+ const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
- const uint64_t nb40 = src4->nb[0];
+ const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
const uint64_t nb41 = src4->nb[1];
const uint64_t nb42 = src4->nb[2];
+ const uint64_t nb43 = src4->nb[3];
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+ const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
- const uint64_t nb50 = src5->nb[0];
+ const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
const uint64_t nb51 = src5->nb[1];
const uint64_t nb52 = src5->nb[2];
+ const uint64_t nb53 = src5->nb[3];
+
+ const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+ const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
const int64_t d_state = ne00;
const int64_t d_inner = ne01;
- const int64_t n_seq_tokens = ne11;
- const int64_t n_seqs = ne02;
+ const int64_t n_head = ne02;
+ const int64_t n_group = ne41;
+ const int64_t n_seq_tokens = ne12;
+ const int64_t n_seqs = ne13;
+
+ id<MTLComputePipelineState> pipeline = nil;
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+ if (ne30 == 1) {
+ // Mamba-2
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+ }
ggml_metal_kargs_ssm_scan args = {
- /*.d_state =*/ d_state,
- /*.d_inner =*/ d_inner,
+ /*.d_state =*/ d_state,
+ /*.d_inner =*/ d_inner,
+ /*.n_head =*/ n_head,
+ /*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
- /*.n_seqs =*/ n_seqs,
- /*.nb00 =*/ nb00,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.nb10 =*/ nb10,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.nb13 =*/ nb13,
- /*.nb20 =*/ nb20,
- /*.nb21 =*/ nb21,
- /*.nb22 =*/ nb22,
- /*.nb30 =*/ nb30,
- /*.nb31 =*/ nb31,
- /*.nb40 =*/ nb40,
- /*.nb41 =*/ nb41,
- /*.nb42 =*/ nb42,
- /*.nb50 =*/ nb50,
- /*.nb51 =*/ nb51,
- /*.nb52 =*/ nb52,
+ /*.n_seqs =*/ n_seqs,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb31 =*/ nb31,
+ /*.nb41 =*/ nb41,
+ /*.nb42 =*/ nb42,
+ /*.nb43 =*/ nb43,
+ /*.nb51 =*/ nb51,
+ /*.nb52 =*/ nb52,
+ /*.nb53 =*/ nb53,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
- [encoder setBytes:&args length:sizeof(args) atIndex:7];
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+ [encoder setBytes:&args length:sizeof(args) atIndex:8];
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ if (ne30 == 1) {
+ // Mamba-2
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } else {
+ GGML_ASSERT(d_inner == 1);
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
} break;
case GGML_OP_RWKV_WKV6:
{
x[0] = sumf;
}
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32(
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
+ device const void * src6,
device float * dst,
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t ir = tgpig.x;
- const int64_t i3 = tgpig.y;
+ const int64_t i1 = 0;
+ const int64_t ir = tgpig.x; // current head
+ const int64_t i3 = tgpig.y; // current seq
+
+ const uint64_t nb00 = sizeof(float);
+ const uint64_t nb10 = sizeof(float);
+ const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
- // const int64_t nr = args.d_inner;
+ const int64_t nr = args.d_inner;
+ const int64_t nh = args.n_head;
+ const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
- // const int64_t n_s = args.n_seqs;
+
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
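+    // dst holds y (nr*nh*n_t*n_seqs floats, i.e. the elements of src1) followed by the updated states; s_off is the byte offset of the states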
+
+ device const int32_t * ids = (device const int32_t *) src6;
+
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
- device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
- device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
- device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
- device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
- device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
- device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
-
- if (i2 > 0) {
- s0 = s;
- }
-
- // i1 == 0
- float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
- float x_dt = x[0] * dt_soft_plus;
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+ const float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
- int64_t i = i0;
- float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ const int64_t i = i0 + i1*nc;
+ const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf;
+
+        // use the output state as the input for the next token
+ s0 = s;
+ }
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// TODO: optimize (e.g. by parallelizing over d_state)
+kernel void kernel_ssm_scan_f32_group(
+ device const void * src0,
+ device const void * src1,
+ device const void * src2,
+ device const void * src3,
+ device const void * src4,
+ device const void * src5,
+ device const void * src6,
+ device float * dst,
+ constant ggml_metal_kargs_ssm_scan & args,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i1 = tgpig.x;
+ const int64_t ir = tgpig.y; // current head
+ const int64_t i3 = tgpig.z; // current seq
+
+ const uint64_t nb00 = sizeof(float);
+ const uint64_t nb10 = sizeof(float);
+ const uint64_t nb20 = sizeof(float);
+
+ const int64_t nc = args.d_state;
+ const int64_t nr = args.d_inner;
+ const int64_t nh = args.n_head;
+ const int64_t ng = args.n_group;
+ const int64_t n_t = args.n_seq_tokens;
+
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+ device const int32_t * ids = (device const int32_t *) src6;
+
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+
+ for (int64_t i2 = 0; i2 < n_t; ++i2) {
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+ const float x_dt = x[0] * dt_soft_plus;
+ const float dA = exp(dt_soft_plus * A[0]);
+ float sumf = 0.0f;
+
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
+ const int64_t i = i0 + i1*nc;
+ const float state = (s0[i] * dA) + (B[i0] * x_dt);
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+
+ y[0] = sumf;
+
+        // use the output state as the input for the next token
+ s0 = s;
}
}
const int64_t n_s = sx->ne[2];
// TODO: maybe support other strides than 1?
- // FIXME: this is always true?
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT(sx->ne[1] == d_inner);
GGML_ASSERT(n_t >= 0);
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
- struct ggml_tensor * C) {
+ struct ggml_tensor * C,
+ struct ggml_tensor * ids) {
GGML_ASSERT(ggml_is_contiguous(s));
- GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
- GGML_ASSERT(ggml_is_matrix(A));
- GGML_ASSERT(ggml_is_3d(B));
- GGML_ASSERT(ggml_is_3d(s));
+ GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
- GGML_ASSERT(ggml_are_same_shape(x, dt));
+ GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+ GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+ GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C));
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
{
const int64_t d_state = s->ne[0];
- const int64_t d_inner = s->ne[1];
- const int64_t n_seq_tokens = x->ne[1];
- const int64_t n_seqs = x->ne[2];
-
- GGML_ASSERT(s->ne[2] == n_seqs);
- GGML_ASSERT(x->ne[0] == d_inner);
- GGML_ASSERT(A->ne[0] == d_state);
- GGML_ASSERT(A->ne[1] == d_inner);
+ const int64_t head_dim = x->ne[0];
+ const int64_t n_head = x->ne[1];
+ const int64_t n_seq_tokens = x->ne[2];
+ const int64_t n_seqs = x->ne[3];
+
+ GGML_ASSERT(dt->ne[0] == n_head);
+ GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+ GGML_ASSERT(dt->ne[2] == n_seqs);
+ GGML_ASSERT(ggml_is_3d(dt));
+ GGML_ASSERT(s->ne[1] == head_dim);
+ GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(B->ne[0] == d_state);
- GGML_ASSERT(B->ne[1] == n_seq_tokens);
- GGML_ASSERT(B->ne[2] == n_seqs);
+ GGML_ASSERT(B->ne[2] == n_seq_tokens);
+ GGML_ASSERT(B->ne[3] == n_seqs);
+ GGML_ASSERT(ids->ne[0] == n_seqs);
+ GGML_ASSERT(ggml_is_vector(ids));
+ GGML_ASSERT(A->ne[1] == n_head);
+ GGML_ASSERT(ggml_is_matrix(A));
+
+ if (A->ne[0] != 1) {
+ // Mamba-1 has more granular decay factors
+ GGML_ASSERT(A->ne[0] == d_state);
+ }
}
// concatenated y + ssm_states
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
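+    // only the states selected by ids are stored in the result, hence ids->ne[0] rather than s->ne[3]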
result->op = GGML_OP_SSM_SCAN;
result->src[0] = s;
result->src[3] = A;
result->src[4] = B;
result->src[5] = C;
+ result->src[6] = ids;
return result;
}
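To make the new signature and layouts concrete, here is a minimal construction sketch mirroring the shapes asserted above; it assumes an existing ggml_context * ctx and illustrative dimension variables, and uses the Mamba-2 layout where A is {1, n_head}:

// illustrative only: all dimension values are placeholders
struct ggml_tensor * s   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
struct ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
struct ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
struct ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head); // {d_state, n_head} for Mamba-1
struct ggml_tensor * B   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
struct ggml_tensor * C   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); // which state slot each sequence reads from
struct ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
// out packs y {head_dim, n_head, n_seq_tokens, n_seqs} followed by the updated per-sequence states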
const ggml_type type;
const int64_t d_state;
- const int64_t d_inner;
+ const int64_t head_dim;
+ const int64_t n_head;
+ const int64_t n_group;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
- return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+ return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
}
test_ssm_scan(ggml_type type = GGML_TYPE_F32,
- int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
- : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+ int64_t d_state = 32,
+                  int64_t head_dim = 1, // > 1 for Mamba-2
+ int64_t n_head = 32,
+ int64_t n_group = 1,
+ int64_t n_seq_tokens = 32,
+ int64_t n_seqs = 32)
+ : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
- ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1 , 1 }.data());
- ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+ ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
+ ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+ ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
+ ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
+ ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+ ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+ ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+ ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
return out;
}
+
+ // similar to test_mul_mat_id
+ void initialize_tensors(ggml_context * ctx) override {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
+ // ids
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<int32_t> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+ }
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
};
// GGML_OP_RWKV_WKV6
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
- test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+ test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
+ test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));