]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
llama : initial Mamba-2 support (llama/9126)
authorcompilade <redacted>
Wed, 2 Jul 2025 17:10:24 +0000 (13:10 -0400)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* llama : initial Mamba-2 support

* ggml : SIMD ggml_ssm_scan for Mamba-2

* ggml : improve ggml_mul speed when masking recurrent states

* llama : support running Mamba-Codestral-7B-v0.1

* llama : fix Mamba-2 conv state saving

* ggml : make the ggml_mul fast broadcast path more consistently formatted

* llama : remove unused variable

* llama : add missing break

* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present

The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires
workarounds to work correctly.

* llama : avoid redundant state copy for Mamba 1 and 2

* metal : attempt to adapt SSM_SCAN for Mamba-2

* metal : fix SSM_SCAN pipeline scope

* metal : use log and exp instead of log1pf and expf in SSM_SCAN

* metal : remove unused arguments for SSM_SCAN

The max index is 31, so trimming the arguments is necessary.

* metal : add back n_seqs to SSM_SCAN args

Whoops, this is needed for the offset in the concatenated output.

* metal : fix SSM_SCAN state head offset

* metal : fix wrong number of tokens per sequence in SSM_SCAN

* ggml : remove unused fast broadcast path in GGML_MUL

This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.

* ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models
to avoid some reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks

* convert : fix flake8 lint

* metal : fix confusion between ; and ,

* metal : add missing args for nb references in ssm_scan_f32_group

* metal : single-user mamba2 inference works

* kv-cache : remove const_cast when setting inputs for s_copy

And also fix multi-user inference for recurrent models
by using cell_id instead of i as the kv cell index
when populating s_copy.

* convert : avoid AutoConfig for Mamba and Mamba2 hparams

* kv-cache : allow context shift for recurrent models

* graph : fix recurrent state copies when avoiding copies

Works, but using lambda functions might not be that clean.

* ggml : fix mamba2 ssm scan when compiled with SVE

* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches

* cuda : implement ssm scan for Mamba2

There is still room for improvement, but it works!

* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2

* mamba : fix mismatched new and delete size for llm_build_mamba

Subclasses of llm_graph_context cannot have extra fields,
because the called destructor is not the one from the subclass.
This otherwise would cause problems when runnning Mamba-(1|2) inference
when compiled -DGGML_SANITIZE_ADDRESS=ON

* cuda : graceful fallback for Mamba-1 models with weird embd size

12 files changed:
include/ggml.h
src/ggml-cpu/ops.cpp
src/ggml-cpu/simd-mappings.h
src/ggml-cpu/vec.cpp
src/ggml-cpu/vec.h
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/ssm-scan.cu
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal
src/ggml.c
tests/test-backend-ops.cpp

index a92495dbb0070330ddec35300ceca0ea161a886f..fbe072a3e0c5b170dd6ade3f2ea04db569655125 100644 (file)
@@ -2028,7 +2028,8 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C);
+            struct ggml_tensor  * C,
+            struct ggml_tensor  * ids);
 
     // partition into non-overlapping windows with padding if needed
     // example:
index 2ae0721e9d171863529e2f1cf45e07342125945e..2f15341954214d321a4035339960d995446998a6 100644 (file)
@@ -8337,120 +8337,210 @@ void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // s
-    const ggml_tensor * src1 = dst->src[1]; // x
-    const ggml_tensor * src2 = dst->src[2]; // dt
-    const ggml_tensor * src3 = dst->src[3]; // A
-    const ggml_tensor * src4 = dst->src[4]; // B
-    const ggml_tensor * src5 = dst->src[5]; // C
+    const ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}
+    const ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+    const ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}
+    const ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
+    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nc = src0->ne[0]; // d_state
+    const int64_t nr = src0->ne[1]; // dim
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1];
+    const int64_t nt = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns = src1->ne[3]; // number of sequences in the batch
 
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    // can't use ggml_nbytes because src1 is not necessarily contiguous
+    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
+
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+    // allows optimizing the modulo since n_group should be a power of 2
+    GGML_ASSERT((ng & -ng) == ng);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
+
+    const int32_t * ids = (const int32_t *) src6->data;
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+              float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+                  float * y  = (      float *) ((      char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+            if (src3->ne[0] == 1) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+                        float sumf = 0.0f;
+#if defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+                        const int ggml_f32_epr = svcntw();
+                        const int ggml_f32_step = 1 * ggml_f32_epr;
+
+                        const int np = (nc & ~(ggml_f32_step - 1));
+
+                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        for (int i = 0; i < np; i += ggml_f32_step) {
+                            // TODO: maybe unroll more?
+                            for (int j = 0; j < 1; j++) {
+                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+                                t0 = GGML_F32_VEC_MUL(t0, adA);
+                                t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+                                t0 = GGML_F32_VEC_ADD(t0, t1);
+
+                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+                            }
+                        }
+
+                        sumf = GGML_F32xt_REDUCE_ONE(sum);
+    #else
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
+
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
+
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
 
-    #ifdef __ARM_FEATURE_SVE
-        for (int i3 = 0; i3 < n_s; ++i3) {
-            for (int i2 = 0; i2 < n_t; ++i2) {
-                const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-                const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-                const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-                const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-                const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                    float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                    float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
-
-                // use the output as the source for the next token-wise iterations
-                if (i2 > 0) { s0 = s; }
-
-                // d_inner
-                for (int i1 = 0; i1 < ir; ++i1) {
-                    float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                    float x_dt = x[i1] * dt_soft_plus;
-                    svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-                    svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-                    svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-                    for (int64_t k = 0; k < nc; k += svcntw()) {
-                        svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
-                        svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
-                        svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
-                        svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
-                        svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-                        t1 = exp_ps_sve(svptrue_b32(), t1);
-                        svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-                        vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
-                        r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-                        GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
+
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+    #endif
+#else
+                        const int np = 0;
+#endif
+                        // d_state
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
                     }
-                    y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
                 }
-            }
-        }
-    #else
-        for (int i3 = 0; i3 < n_s; ++i3) {
-            for (int i2 = 0; i2 < n_t; ++i2) {
-                const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-                const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-                const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-                const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-                const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                    float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                    float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
-
-                // use the output as the source for the next token-wise iterations
-                if (i2 > 0) { s0 = s; }
-
-                // d_inner
-                for (int i1 = 0; i1 < ir; ++i1) {
-                    // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                    float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                    float x_dt = x[i1] * dt_soft_plus;
-                    float sumf = 0.0f;
-                    // d_state
-                    for (int i0 = 0; i0 < nc; ++i0) {
-                        int i = i0 + i1*nc;
-                        // state = prev_state * dA + dB * x
-                        float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                        // y = rowwise_dotprod(state, C)
-                        sumf += state * C[i0];
-                        s[i] = state;
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                        // d_state
+                        // TODO: what happens when (d_state % svcntw()) != 0?
+                        for (int64_t k = 0; k < nc; k += svcntw()) {
+                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                            t1 = exp_ps_sve(svptrue_b32(), t1);
+                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+                        }
+                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+                        float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int i = i0 + ii*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[i] = state;
+                        }
+                        y[ii] = sumf;
+#endif
                     }
-                    y[i1] = sumf;
                 }
             }
+            // use the output as the source when it's not the first token-wise iteration
+            s0 = s;
         }
-    #endif
+    }
 }
 
 void ggml_compute_forward_ssm_scan(
index b68ac0dd68b40206875d98cec866691673d0bd71..b4ad68c9fd647acbc2b04b86abbc65fa29ea77e2 100644 (file)
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 #define GGML_F32xt_LOAD(...)              GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
 #define GGML_F32xt_STORE_IMPL(pg,a,b)     svst1_f32(pg, a, b)
 #define GGML_F32xt_STORE(...)             GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
-#define GGML_F32xt_FMA_IMPL(pg, a, b, c)  svmad_f32_m(pg, a, b, c)
+#define GGML_F32xt_FMA_IMPL(pg, a, b, c)  svmad_f32_m(pg, b, c, a)
 #define GGML_F32xt_FMA(...)               GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
 #define GGML_F32xt_ADD_IMPL(pg, a, b)     svadd_f32_m(pg, a, b)
 #define GGML_F32xt_ADD(...)               GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
index ed5d7aefc35b324884bed2ef7f17d628415acc6c..a8156011eba2dcddd12e42120b1b4e1788c52a8c 100644 (file)
@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
         for (int i = 0; i < np; i += ggml_f32_step) {
             ax1 = GGML_F32_VEC_LOAD(x + i);
             ay1 = GGML_F32_VEC_LOAD(y + i);
-            sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
 
             ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
             ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
-            sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
+            sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
 
             ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
             ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
-            sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
+            sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
 
             ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
             ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
-            sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
+            sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
 
             ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
             ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
-            sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
+            sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
 
             ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
             ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
-            sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
+            sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
 
             ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
             ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
-            sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
+            sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
 
             ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
             ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
-            sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
+            sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
         }
         // leftovers
         // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
         for (int i = np; i < np2; i += ggml_f32_epr) {
             ax1 = GGML_F32_VEC_LOAD(x + i);
             ay1 = GGML_F32_VEC_LOAD(y + i);
-            sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
         }
         // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
         if (np2 < n) {
index d5507d75646d4df51931c082aa1008fa01b09df4..c432c990818226c3d3cd357cab1453c6cdf7e737 100644 (file)
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
 
             ax1 = GGML_F32_VEC_LOAD(x + i);
             ay1 = GGML_F32_VEC_LOAD(y + i);
-            ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
 
             GGML_F32_VEC_STORE(y + i, ay1);
 
             ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
             ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
-            ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
+            ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
 
             GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
 
             ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
             ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
-            ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
+            ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
 
             GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
 
             ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
             ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
-            ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
+            ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
 
             GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
 
             ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
             ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
-            ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
+            ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
 
             GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
 
             ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
             ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
-            ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
+            ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
 
             GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
 
             ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
             ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
-            ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
+            ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
 
             GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
 
             ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
             ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
-            ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
+            ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
 
             GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
         }
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
         for (int i = np; i < np2; i += ggml_f32_epr) {
             ax1 = GGML_F32_VEC_LOAD(x + i);
             ay1 = GGML_F32_VEC_LOAD(y + i);
-            ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+            ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
 
             GGML_F32_VEC_STORE(y + i, ay1);
         }
index 0fa7fb953d815b3d5b8d5787db336c165cae5a06..121e83641b23df227a731a070a45e2e72c90c005 100644 (file)
@@ -3321,9 +3321,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LOG:
-        case GGML_OP_SSM_SCAN:
-        case GGML_OP_SSM_CONV:
             return true;
+        case GGML_OP_SSM_SCAN: {
+            if (op->src[3]->ne[0] == 1) {
+                // Mamba2
+                // (kernel only supports d_state == 128 && d_head % 16 == 0)
+                return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+            } else {
+                // Mamba
+                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+            }
+        }
+        case GGML_OP_SSM_CONV: {
+            // assumes d_inner % threads == 0
+            return op->src[0]->ne[1] % 128 == 0;
+        }
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
index 2d34b836054f8eaed4e3decda8f1fe55854beceb..dc3b1a9a8cbf08bb44263ca63a42875f99717100 100644 (file)
@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
 __global__ void __launch_bounds__(splitD, 2)
     ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
-                 const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
-                 const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                 const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                 float * __restrict__ dst, const int64_t L) {
-    GGML_UNUSED(src1_nb0);
-    GGML_UNUSED(src2_nb0);
+                 const int32_t * __restrict__ src6, float * __restrict__ dst,
+                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+                 const int64_t s_off, const int64_t d_inner, const int64_t L) {
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x;  // split along B
-    const int bidy = blockIdx.y;  // split along D
+    const int bidx = blockIdx.x;  // split along B (sequences)
+    const int bidy = blockIdx.y;  // split along D (d_inner)
     const int tid  = threadIdx.x;
     const int wid  = tid / 32;
     const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
     float *                 smem_A     = smem;
     float *                 smem_s0    = smem_A + splitD * stride_sA;
 
-    const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
-    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
+    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
     const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb2));
-    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb2));
-    float *       y_block  = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
-    float *       s_block  = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
+    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
+    float *       y_block  = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
+    float *       s_block  = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
 
-    const int stride_s0 = src0_nb1 / sizeof(float);
-    const int stride_x  = src1_nb1 / sizeof(float);
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_x  = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
     const int stride_A  = src3_nb1 / sizeof(float);
-    const int stride_B  = src4_nb1 / sizeof(float);
-    const int stride_C  = src5_nb1 / sizeof(float);
+    const int stride_B  = src4_nb2 / sizeof(float);
+    const int stride_C  = src5_nb2 / sizeof(float);
     const int stride_s  = stride_s0;
-    const int stride_y  = stride_x;
+    const int stride_y  = d_inner;
 
     // can N not be 16? for example 32?
     if (N == 16) {
@@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2)
     }
 }
 
+// assumes as many threads as d_state
+template <int splitH, int d_state>
+__global__ void __launch_bounds__(d_state, 1)
+    ssm_scan_f32_group(
+        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+        const int32_t * __restrict__ src6, float * __restrict__ dst,
+        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+        const int src2_nb1, const int src2_nb2, const int src3_nb1,
+        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
+
+    const int head_idx = (blockIdx.x * splitH) / d_head;
+    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
+    const int seq_idx = blockIdx.y;
+
+    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_block  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
+    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_block  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_block  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_block  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float *       y_block  = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
+    float *       s_block  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+    // strides across n_seq_tokens
+    const int stride_x  = src1_nb2 / sizeof(float);
+    const int stride_dt = src2_nb1 / sizeof(float);
+    const int stride_B  = src4_nb2 / sizeof(float);
+    const int stride_C  = src5_nb2 / sizeof(float);
+    const int stride_y  = n_head * d_head;
+
+    float state[splitH];
+    // for the parallel accumulation
+    __shared__ float stateC[splitH * d_state];
+
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        state[j] = s0_block[j * d_state + threadIdx.x];
+    }
+
+    for (int64_t i = 0; i < n_tok; i++) {
+        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+        // TODO: only calculate B and C once per head group
+        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+        float dt_soft_plus = dt_block[i * stride_dt];
+        if (dt_soft_plus <= 20.0f) {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
+        }
+        const float dA = expf(dt_soft_plus * A_block[0]);
+        const float B = B_block[i * stride_B + threadIdx.x];
+        const float C = C_block[i * stride_C + threadIdx.x];
+
+        // across d_head
+#pragma unroll
+        for (int j = 0; j < splitH; j++) {
+            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+            state[j] = (state[j] * dA) + (B * x_dt);
+
+            stateC[j * d_state + threadIdx.x] = state[j] * C;
+        }
+
+        __syncthreads();
+
+        // parallel accumulation for stateC
+        // TODO: simplify
+        {
+            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+            // reduce until w matches the warp size
+            // TODO: does this work even when the physical warp size is 64?
+#pragma unroll
+            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+                // (assuming there are d_state threads)
+#pragma unroll
+                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+                    // TODO: check for bank conflicts
+                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+                    stateC[k] += stateC[k + (w >> 1)];
+
+                }
+                __syncthreads();
+            }
+
+            static_assert(splitH >= d_state / WARP_SIZE);
+
+#pragma unroll
+            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+                y = warp_reduce_sum(y);
+
+                // store the above accumulations
+                if (threadIdx.x % WARP_SIZE == 0) {
+                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+                    y_block[i * stride_y + k] = y;
+                }
+            }
+        }
+    }
+
+    // write back the state
+#pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        s_block[j * d_state + threadIdx.x] = state[j];
+    }
+}
+
 static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
-                              const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
-                              const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
-                              const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                              const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                              float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
+                              const float * src4, const float * src5, const int32_t * src6, float * dst,
+                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
     const int threads = 128;
-    // todo: consider D cannot be divided,does this situation exist?
-    GGML_ASSERT(D % threads == 0);
-    const dim3 blocks(B, (D + threads - 1) / threads, 1);
-    const int  smem_size = (threads * (N + 1) * 2) * sizeof(float);
-    if (N == 16) {
-        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-            src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
-            src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+    if (src3_nb1 == sizeof(float)) {
+        // Mamba-2
+        if (d_state == 128) {
+            GGML_ASSERT(d_state % threads == 0);
+            // NOTE: can be any power of two between 4 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=128.");
+        }
     } else {
-        GGML_ABORT("doesn't support N!=16.");
+        // Mamba-1
+        GGML_ASSERT(n_head % threads == 0);
+        GGML_ASSERT(head_dim == 1);
+        GGML_ASSERT(n_group == 1);
+        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
+        const int  smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
+        if (d_state == 16) {
+            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=16.");
+        }
     }
 }
 
@@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src3 = dst->src[3];  // A
     const struct ggml_tensor * src4 = dst->src[4];  // B
     const struct ggml_tensor * src5 = dst->src[5];  // C
-
-    //   const int64_t d_state = src0->ne[0];
-    //   const int64_t d_inner = src0->ne[1];
-    //   const int64_t l = src1->ne[1];
-    //   const int64_t b = src0->ne[2];
+    const struct ggml_tensor * src6 = dst->src[6];  // ids
 
     const int64_t nc  = src0->ne[0];  // d_state
-    const int64_t nr  = src0->ne[1];  // d_inner
-    const int64_t n_t = src1->ne[1];  // number of tokens per sequence
-    const int64_t n_s = src0->ne[2];  // number of sequences in the batch
+    const int64_t nr  = src0->ne[1];  // head_dim or 1
+    const int64_t nh  = src1->ne[1];  // n_head
+    const int64_t ng  = src4->ne[1];  // n_group
+    const int64_t n_t = src1->ne[2];  // number of tokens per sequence
+    const int64_t n_s = src1->ne[3];  // number of sequences in the batch
+
+    const int64_t s_off = ggml_nelements(src1) * sizeof(float);
 
-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
@@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float * src3_d = (const float *) src3->data;
     const float * src4_d = (const float *) src4->data;
     const float * src5_d = (const float *) src5->data;
+    const int32_t * src6_d = (const int32_t *) src6->data;
     float *       dst_d  = (float *) dst->data;
     cudaStream_t  stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src6->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
-                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
-                      src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
 }
index 8c2a971927414de104d797e41eb35a927f953bd8..adb65c0e7b7defea78fcd19de0a544ae86c79788 100644 (file)
@@ -513,26 +513,25 @@ typedef struct {
 typedef struct {
     int64_t  d_state;
     int64_t  d_inner;
+    int64_t  n_head;
+    int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
-    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
-    uint64_t nb10;
+    uint64_t nb03;
     uint64_t nb11;
     uint64_t nb12;
     uint64_t nb13;
-    uint64_t nb20;
     uint64_t nb21;
     uint64_t nb22;
-    uint64_t nb30;
     uint64_t nb31;
-    uint64_t nb40;
     uint64_t nb41;
     uint64_t nb42;
-    uint64_t nb50;
+    uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t nb53;
 } ggml_metal_kargs_ssm_scan;
 
 typedef struct {
index 1a0968ed4d23173fcb635653bd41c910ab2d3af8..de40430ef3565f4781e8661c2960b24fe3fa584e 100644 (file)
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -1196,6 +1197,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                    ssm_conv_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,              ssm_scan_f32_group,              true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                   rwkv_wkv6_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                   rwkv_wkv7_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                  mul_mv_f32_f32,                  has_simdgroup_reduction);
@@ -2809,71 +2811,91 @@ static bool ggml_metal_encode_node(
                 struct ggml_tensor * src3 = node->src[3];
                 struct ggml_tensor * src4 = node->src[4];
                 struct ggml_tensor * src5 = node->src[5];
+                struct ggml_tensor * src6 = node->src[6];
 
                 GGML_ASSERT(src3);
                 GGML_ASSERT(src4);
                 GGML_ASSERT(src5);
+                GGML_ASSERT(src6);
 
                 size_t offs_src3 = 0;
                 size_t offs_src4 = 0;
                 size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
                 id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
                 id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
 
-                const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                const int64_t  ne30 = src3->ne[0];
                 const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
 
-                const uint64_t nb30 = src3->nb[0];
+                const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
                 const uint64_t nb31 = src3->nb[1];
 
                 const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                const int64_t  ne41 = src4->ne[1];
                 const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+                const int64_t  ne43 = src4->ne[3]; GGML_UNUSED(ne43);
 
-                const uint64_t nb40 = src4->nb[0];
+                const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
                 const uint64_t nb41 = src4->nb[1];
                 const uint64_t nb42 = src4->nb[2];
+                const uint64_t nb43 = src4->nb[3];
 
                 const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
                 const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
                 const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+                const int64_t  ne53 = src5->ne[3]; GGML_UNUSED(ne53);
 
-                const uint64_t nb50 = src5->nb[0];
+                const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
                 const uint64_t nb51 = src5->nb[1];
                 const uint64_t nb52 = src5->nb[2];
+                const uint64_t nb53 = src5->nb[3];
+
+                const int64_t  ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+                const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
 
                 const int64_t d_state      = ne00;
                 const int64_t d_inner      = ne01;
-                const int64_t n_seq_tokens = ne11;
-                const int64_t n_seqs       = ne02;
+                const int64_t n_head       = ne02;
+                const int64_t n_group      = ne41;
+                const int64_t n_seq_tokens = ne12;
+                const int64_t n_seqs       = ne13;
+
+                id<MTLComputePipelineState> pipeline = nil;
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                if (ne30 == 1) {
+                    // Mamba-2
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+                }
 
                 ggml_metal_kargs_ssm_scan args = {
-                    /*.d_state =*/ d_state,
-                    /*.d_inner =*/ d_inner,
+                    /*.d_state      =*/ d_state,
+                    /*.d_inner      =*/ d_inner,
+                    /*.n_head       =*/ n_head,
+                    /*.n_group      =*/ n_group,
                     /*.n_seq_tokens =*/ n_seq_tokens,
-                    /*.n_seqs =*/ n_seqs,
-                    /*.nb00 =*/ nb00,
-                    /*.nb01 =*/ nb01,
-                    /*.nb02 =*/ nb02,
-                    /*.nb10 =*/ nb10,
-                    /*.nb11 =*/ nb11,
-                    /*.nb12 =*/ nb12,
-                    /*.nb13 =*/ nb13,
-                    /*.nb20 =*/ nb20,
-                    /*.nb21 =*/ nb21,
-                    /*.nb22 =*/ nb22,
-                    /*.nb30 =*/ nb30,
-                    /*.nb31 =*/ nb31,
-                    /*.nb40 =*/ nb40,
-                    /*.nb41 =*/ nb41,
-                    /*.nb42 =*/ nb42,
-                    /*.nb50 =*/ nb50,
-                    /*.nb51 =*/ nb51,
-                    /*.nb52 =*/ nb52,
+                    /*.n_seqs       =*/ n_seqs,
+                    /*.nb01         =*/ nb01,
+                    /*.nb02         =*/ nb02,
+                    /*.nb03         =*/ nb03,
+                    /*.nb11         =*/ nb11,
+                    /*.nb12         =*/ nb12,
+                    /*.nb13         =*/ nb13,
+                    /*.nb21         =*/ nb21,
+                    /*.nb22         =*/ nb22,
+                    /*.nb31         =*/ nb31,
+                    /*.nb41         =*/ nb41,
+                    /*.nb42         =*/ nb42,
+                    /*.nb43         =*/ nb43,
+                    /*.nb51         =*/ nb51,
+                    /*.nb52         =*/ nb52,
+                    /*.nb53         =*/ nb53,
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -2883,10 +2905,17 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-                [encoder setBytes:&args    length:sizeof(args) atIndex:7];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:8];
 
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                if (ne30 == 1) {
+                    // Mamba-2
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } else {
+                    GGML_ASSERT(d_inner == 1);
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
             } break;
         case GGML_OP_RWKV_WKV6:
             {
index dfff669726b94ea9a4ec7d784a250c200f27f22a..a7657e0fc3c36b976d9bff0c3209d026e426a426 100644 (file)
@@ -1596,7 +1596,7 @@ kernel void kernel_ssm_conv_f32(
     x[0] = sumf;
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
 kernel void kernel_ssm_scan_f32(
         device const void * src0,
         device const void * src1,
@@ -1604,46 +1604,119 @@ kernel void kernel_ssm_scan_f32(
         device const void * src3,
         device const void * src4,
         device const void * src5,
+        device const void * src6,
         device      float * dst,
         constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i3 = tgpig.y;
+    const int64_t i1 = 0;
+    const int64_t ir = tgpig.x; // current head
+    const int64_t i3 = tgpig.y; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
 
     const int64_t nc  = args.d_state;
-    // const int64_t nr  = args.d_inner;
+    const int64_t nr  = args.d_inner;
+    const int64_t nh  = args.n_head;
+    const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
-    // const int64_t n_s = args.n_seqs;
+
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
-        device       float * y  = (device       float *) ((device       char *) dst  + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
-        device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb01 + i3*args.nb02 +    args.nb13);
-
-        if (i2 > 0) {
-            s0 = s;
-        }
-
-        // i1 == 0
-        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        float x_dt = x[0] * dt_soft_plus;
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
         float sumf = 0.0f;
 
         for (int64_t i0 = 0; i0 < nc; ++i0) {
-            int64_t i = i0;
-            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
             sumf += state * C[i0];
             s[i] = state;
         }
 
         y[0] = sumf;
+
+        // recurse
+        s0 = s;
+    }
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// TODO: optimize (e.g. by parallelizing over d_state)
+kernel void kernel_ssm_scan_f32_group(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device const void * src6,
+        device      float * dst,
+        constant ggml_metal_kargs_ssm_scan & args,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i1 = tgpig.x;
+    const int64_t ir = tgpig.y; // current head
+    const int64_t i3 = tgpig.z; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
+
+    const int64_t nc  = args.d_state;
+    const int64_t nr  = args.d_inner;
+    const int64_t nh  = args.n_head;
+    const int64_t ng  = args.n_group;
+    const int64_t n_t = args.n_seq_tokens;
+
+    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
+        const float dA = exp(dt_soft_plus * A[0]);
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * dA) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+
+        // recurse
+        s0 = s;
     }
 }
 
index f1245c50c9299013d08d5213a470b398b470ef9f..90c121339fb9014116c81ce828c9a2c39b5fa546 100644 (file)
@@ -4829,7 +4829,6 @@ struct ggml_tensor * ggml_ssm_conv(
     const int64_t n_s     = sx->ne[2];
 
     // TODO: maybe support other strides than 1?
-    // FIXME: this is always true?
     GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
     GGML_ASSERT(sx->ne[1] == d_inner);
     GGML_ASSERT(n_t >= 0);
@@ -4852,36 +4851,49 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
-        struct ggml_tensor  * C) {
+        struct ggml_tensor  * C,
+        struct ggml_tensor  * ids) {
     GGML_ASSERT(ggml_is_contiguous(s));
-    GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(ggml_is_matrix(A));
-    GGML_ASSERT(ggml_is_3d(B));
-    GGML_ASSERT(ggml_is_3d(s));
+    GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
     GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
-    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+    GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+    GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     GGML_ASSERT(ggml_are_same_shape(B, C));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     {
         const int64_t d_state      = s->ne[0];
-        const int64_t d_inner      = s->ne[1];
-        const int64_t n_seq_tokens = x->ne[1];
-        const int64_t n_seqs       = x->ne[2];
-
-        GGML_ASSERT(s->ne[2] == n_seqs);
-        GGML_ASSERT(x->ne[0] == d_inner);
-        GGML_ASSERT(A->ne[0] == d_state);
-        GGML_ASSERT(A->ne[1] == d_inner);
+        const int64_t head_dim     = x->ne[0];
+        const int64_t n_head       = x->ne[1];
+        const int64_t n_seq_tokens = x->ne[2];
+        const int64_t n_seqs       = x->ne[3];
+
+        GGML_ASSERT(dt->ne[0] == n_head);
+        GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+        GGML_ASSERT(dt->ne[2] == n_seqs);
+        GGML_ASSERT(ggml_is_3d(dt));
+        GGML_ASSERT(s->ne[1] == head_dim);
+        GGML_ASSERT(s->ne[2] == n_head);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_seq_tokens);
-        GGML_ASSERT(B->ne[2] == n_seqs);
+        GGML_ASSERT(B->ne[2] == n_seq_tokens);
+        GGML_ASSERT(B->ne[3] == n_seqs);
+        GGML_ASSERT(ids->ne[0] == n_seqs);
+        GGML_ASSERT(ggml_is_vector(ids));
+        GGML_ASSERT(A->ne[1] == n_head);
+        GGML_ASSERT(ggml_is_matrix(A));
+
+        if (A->ne[0] != 1) {
+            // Mamba-1 has more granular decay factors
+            GGML_ASSERT(A->ne[0] == d_state);
+        }
     }
 
     // concatenated y + ssm_states
-    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
 
     result->op   = GGML_OP_SSM_SCAN;
     result->src[0] = s;
@@ -4890,6 +4902,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
+    result->src[6] = ids;
 
     return result;
 }
index a4f94c2594038edd5863ca5b2750c76c0dc01702..51bbcf92c9054ba755d9c59515029b9d48c618a5 100644 (file)
@@ -2084,28 +2084,58 @@ struct test_ssm_scan : public test_case {
     const ggml_type type;
 
     const int64_t d_state;
-    const int64_t d_inner;
+    const int64_t head_dim;
+    const int64_t n_head;
+    const int64_t n_group;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+        return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
     }
 
     test_ssm_scan(ggml_type type = GGML_TYPE_F32,
-            int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
-        : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+            int64_t d_state = 32,
+            int64_t head_dim = 1, // non-zero for Mamba-2
+            int64_t n_head  = 32,
+            int64_t n_group = 1,
+            int64_t n_seq_tokens = 32,
+            int64_t n_seqs = 32)
+        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner,      n_seqs, 1 }.data());
-        ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner,      1     , 1 }.data());
-        ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+        ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state,  head_dim,     n_head,       n_seqs);
+        ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, head_dim, n_head,       n_seq_tokens, n_seqs);
+        ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head,   n_seq_tokens, n_seqs);
+        ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
+        ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state,  n_group,      n_seq_tokens, n_seqs);
+        ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state,  n_group,      n_seq_tokens, n_seqs);
+        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32,  n_seqs);
+        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
         return out;
     }
+
+    // similar to test_mul_mat_id
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                // ids
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
 };
 
 // GGML_OP_RWKV_WKV6
@@ -4498,7 +4528,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
 
-    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
 
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));