class MambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA
+ def __init__(self, dir_model: Path, *args, **kwargs):
+ # Avoid using AutoConfig for hparams
+ hparams = kwargs.pop("hparams", None)
+ if hparams is None:
+ with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+ hparams = json.load(f)
+ super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 8
return [(new_name, data_torch)]
+@ModelBase.register("Mamba2ForCausalLM")
+class Mamba2Model(TextModel):
+    """Convert ``Mamba2ForCausalLM`` checkpoints to the GGUF ``MAMBA2`` architecture."""
+
+    model_arch = gguf.MODEL_ARCH.MAMBA2
+
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        """Read hparams straight from ``config.json`` unless ``hparams`` is passed in."""
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
+    def set_vocab(self):
+        """Pad ``vocab_size`` and load the vocab that matches the tokenizer files present.
+
+        Raises:
+            NotImplementedError: only ``tokenizer.model.v3`` exists and has to be renamed.
+        """
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size up to the next multiple of pad_vocab_size_multiple (16 by default)
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        elif (self.dir_model / "tokenizer.model.v3").is_file():
+            # mamba-codestral
+            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
+        elif (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            self._set_vocab_builtin("gpt-neox", vocab_size)
+
+    def set_gguf_parameters(self):
+        """Write the SSM hyperparameters, falling back to defaults for missing config keys."""
+        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        # TODO: does this really matter?
+        assert d_inner == 2 * d_model
+        assert d_inner % head_dim == 0
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        # the time step rank slot carries the number of heads (d_inner // head_dim)
+        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        """Rename one checkpoint tensor and reshape/transform it for GGUF.
+
+        Yields a single ``(new_name, tensor)`` pair.
+        """
+
+        if name.startswith("model.backbone") or name.startswith("model.lm_head"):
+            # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
+            name = name.removeprefix("model.")
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+
+        new_name = self.map_tensor_name(name)
+
+        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
+            data_torch = data_torch.squeeze()
+        elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            data_torch = data_torch.reshape((*data_torch.shape, 1))
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            # same lookup and fallback as in set_gguf_parameters, so a missing or
+            # null "n_groups" means a single group in both places
+            n_group = self.find_hparam(["n_groups"], optional=True) or 1
+            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+
@ModelBase.register("CohereForCausalLM")
class CommandR2Model(TextModel):
model_arch = gguf.MODEL_ARCH.COMMAND_R
# maybe we should fallback to text model's arch in that case, since not many models have both
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
- arch = hparams["architectures"][0]
+ arch = None
+ if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+ arch = arches[0]
+ elif "ssm_cfg" in hparams:
+ # For non-hf Mamba and Mamba2 models
+ arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
+ if arch is None:
+ raise ValueError("Failed to detect model architecture")
return arch
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
- struct ggml_tensor * C);
+ struct ggml_tensor * C,
+ struct ggml_tensor * ids);
// partition into non-overlapping windows with padding if needed
// example:
+// Selective state-space scan (f32) used for both Mamba-1 (element-wise decay) and Mamba-2 (one decay per head).
+// dst receives y {dim, n_head, n_seq_tokens, n_seqs} first, followed at byte offset s_off by the updated states.
+// The initial state of sequence i3 is read from src0 at index ids[i3]; the work is split across threads by head.
static void ggml_compute_forward_ssm_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0]; // s
- const ggml_tensor * src1 = dst->src[1]; // x
- const ggml_tensor * src2 = dst->src[2]; // dt
- const ggml_tensor * src3 = dst->src[3]; // A
- const ggml_tensor * src4 = dst->src[4]; // B
- const ggml_tensor * src5 = dst->src[5]; // C
+ const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
+ const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+ const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
+ const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
+ const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
const int ith = params->ith;
const int nth = params->nth;
- const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // dim
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1]; // n_group
+ const int64_t nt = src1->ne[2]; // number of tokens per sequence
+ const int64_t ns = src1->ne[3]; // number of sequences in the batch
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ // can't use ggml_nbytes because src1 is not necessarily contiguous
+ const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
+
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+ // allows optimizing the modulo since n_group should be a power of 2
+ GGML_ASSERT((ng & -ng) == ng);
+ // (below, the B/C group of head h is selected with (h & (ng - 1)) in place of h % ng)
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // heads per thread
+ const int dh = (nh + nth - 1)/nth;
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
- const int ir = ir1 - ir0;
+ // head range for this thread
+ const int ih0 = dh*ith;
+ const int ih1 = MIN(ih0 + dh, nh);
+
+ const int32_t * ids = (const int32_t *) src6->data;
+ // ids[i3] is the index into src0 of the initial state used for sequence i3
+
+ for (int i3 = 0; i3 < ns; ++i3) {
+ const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
+ float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+ for (int i2 = 0; i2 < nt; ++i2) {
+ const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+ const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+ const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+ float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+
+ if (src3->ne[0] == 1) {
+ // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dA = expf(dt_soft_plus * A[h]);
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+ float sumf = 0.0f;
+#if defined(GGML_SIMD)
+ #if defined(__ARM_FEATURE_SVE)
+ const int ggml_f32_epr = svcntw();
+ const int ggml_f32_step = 1 * ggml_f32_epr;
+
+ const int np = (nc & ~(ggml_f32_step - 1));
+
+ GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ // TODO: maybe unroll more?
+ for (int j = 0; j < 1; j++) {
+ GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+ GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+ GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+ t0 = GGML_F32_VEC_MUL(t0, adA);
+ t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+ t0 = GGML_F32_VEC_ADD(t0, t1);
+
+ sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+ GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+ }
+ }
+
+ sumf = GGML_F32xt_REDUCE_ONE(sum);
+ #else
+ const int np = (nc & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+ GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC az[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+ ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+ az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+ ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+ ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
- #ifdef __ARM_FEATURE_SVE
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
- svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
- svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
- for (int64_t k = 0; k < nc; k += svcntw()) {
- svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
- svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
- svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
- svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
-
- svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
- t1 = exp_ps_sve(svptrue_b32(), t1);
- svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
- vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
- r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
- GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+ GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);
+ #endif
+#else
+ // no SIMD: the scalar loop below covers all of d_state
+ const int np = 0;
+#endif
+ // scalar loop over the elements in [np, nc) not handled by the vectorized part
+ // d_state
+ for (int i0 = np; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * dA) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
}
- y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
}
- }
- }
- #else
- for (int i3 = 0; i3 < n_s; ++i3) {
- for (int i2 = 0; i2 < n_t; ++i2) {
- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
- // use the output as the source for the next token-wise iterations
- if (i2 > 0) { s0 = s; }
-
- // d_inner
- for (int i1 = 0; i1 < ir; ++i1) {
- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
- float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
- float x_dt = x[i1] * dt_soft_plus;
- float sumf = 0.0f;
- // d_state
- for (int i0 = 0; i0 < nc; ++i0) {
- int i = i0 + i1*nc;
- // state = prev_state * dA + dB * x
- float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
- // y = rowwise_dotprod(state, C)
- sumf += state * C[i0];
- s[i] = state;
+ } else {
+ // Mamba-1 has an element-wise decay factor for the states
+
+ // n_head
+ for (int h = ih0; h < ih1; ++h) {
+ // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+ const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+ // dim
+ for (int i1 = 0; i1 < nr; ++i1) {
+ const int ii = i1 + h*nr;
+ const float x_dt = x[ii] * dt_soft_plus;
+#if defined(__ARM_FEATURE_SVE)
+ svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+ // d_state
+ // TODO: what happens when (d_state % svcntw()) != 0?
+ for (int64_t k = 0; k < nc; k += svcntw()) {
+ svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
+ svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
+ svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
+ svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
+
+ svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+ t1 = exp_ps_sve(svptrue_b32(), t1);
+ svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+ vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
+ r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+ GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
+ }
+ y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
+#else
+ float sumf = 0.0f;
+ // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+ // and also because expf is used within the loop.
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ const int i = i0 + ii*nc;
+ const int ig = i0 + (h & (ng - 1))*nc;
+ // state = prev_state * dA + dB * x
+ const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[ig];
+ s[i] = state;
+ }
+ y[ii] = sumf;
+#endif
}
- y[i1] = sumf;
}
}
+ // use the output as the source when it's not the first token-wise iteration
+ s0 = s;
}
- #endif
+ }
}
void ggml_compute_forward_ssm_scan(
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
-#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
+#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
for (int i = 0; i < np; i += ggml_f32_step) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
- sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
+ sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
- sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
+ sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
- sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
+ sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
- sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
+ sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
- sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
+ sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
- sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
+ sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
- sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
+ sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
}
// leftovers
// Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+ sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
}
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
if (np2 < n) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
- ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
+ ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
- ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
+ ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
- ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
+ ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
- ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
+ ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
- ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
+ ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
- ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
+ ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
- ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
+ ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
}
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
- ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+ ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
}
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LOG:
- case GGML_OP_SSM_SCAN:
- case GGML_OP_SSM_CONV:
return true;
+ case GGML_OP_SSM_SCAN: {
+ if (op->src[3]->ne[0] == 1) {
+ // Mamba2
+ // (kernel only supports d_state == 128 && d_head % 16 == 0)
+ return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
+ } else {
+ // Mamba
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+ }
+ }
+ case GGML_OP_SSM_CONV: {
+ // assumes d_inner % threads == 0
+ return op->src[0]->ne[1] % 128 == 0;
+ }
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_DIAG_MASK_INF:
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
- const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
- const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
- const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
- float * __restrict__ dst, const int64_t L) {
- GGML_UNUSED(src1_nb0);
- GGML_UNUSED(src2_nb0);
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+ const int64_t s_off, const int64_t d_inner, const int64_t L) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- const int bidx = blockIdx.x; // split along B
- const int bidy = blockIdx.y; // split along D
+ const int bidx = blockIdx.x; // split along B (sequences)
+ const int bidy = blockIdx.y; // split along D (d_inner)
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
- const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
- const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+ const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
+ const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
- const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
- const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
- float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
- float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+ const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
+ const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
+ float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
+ float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
- const int stride_s0 = src0_nb1 / sizeof(float);
- const int stride_x = src1_nb1 / sizeof(float);
+ const int stride_s0 = src0_nb2 / sizeof(float);
+ const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
- const int stride_B = src4_nb1 / sizeof(float);
- const int stride_C = src5_nb1 / sizeof(float);
+ const int stride_B = src4_nb2 / sizeof(float);
+ const int stride_C = src5_nb2 / sizeof(float);
const int stride_s = stride_s0;
- const int stride_y = stride_x;
+ const int stride_y = d_inner;
// can N not be 16? for example 32?
if (N == 16) {
}
}
+// assumes as many threads as d_state
+// Scan for the scalar-decay-per-head case: A_block[0] is the only A value read per head.
+// blockIdx.x selects a slice of splitH consecutive elements of n_head * d_head, blockIdx.y the sequence;
+// thread threadIdx.x handles state index threadIdx.x of each of those splitH elements.
+// y is written at the start of dst and the final states at byte offset s_off.
+template <int splitH, int d_state>
+__global__ void __launch_bounds__(d_state, 1)
+ ssm_scan_f32_group(
+ const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+ const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+ const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
+
+ // head containing this block's slice, and the byte offset of the slice within that head
+ const int head_idx = (blockIdx.x * splitH) / d_head;
+ const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
+ const int seq_idx = blockIdx.y;
+
+ // byte offset of this head's group in B and C; the mask stands in for a modulo
+ // (presumably n_group is a power of 2 — confirm against the caller)
+ const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+ // src6[seq_idx] selects which state of src0 is the initial state of this sequence
+ const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+ const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
+ const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+ const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+ const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+ const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+ float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
+ float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+ // strides across n_seq_tokens
+ const int stride_x = src1_nb2 / sizeof(float);
+ const int stride_dt = src2_nb1 / sizeof(float);
+ const int stride_B = src4_nb2 / sizeof(float);
+ const int stride_C = src5_nb2 / sizeof(float);
+ const int stride_y = n_head * d_head;
+
+ // state[j]: this thread's state element (index threadIdx.x) for the j-th element of the slice
+ float state[splitH];
+ // for the parallel accumulation
+ __shared__ float stateC[splitH * d_state];
+
+#pragma unroll
+ for (int j = 0; j < splitH; j++) {
+ state[j] = s0_block[j * d_state + threadIdx.x];
+ }
+
+ for (int64_t i = 0; i < n_tok; i++) {
+ // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+ // TODO: only calculate B and C once per head group
+ // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+ // softplus(dt); inputs above 20 are passed through unchanged
+ float dt_soft_plus = dt_block[i * stride_dt];
+ if (dt_soft_plus <= 20.0f) {
+ dt_soft_plus = log1pf(expf(dt_soft_plus));
+ }
+ const float dA = expf(dt_soft_plus * A_block[0]);
+ const float B = B_block[i * stride_B + threadIdx.x];
+ const float C = C_block[i * stride_C + threadIdx.x];
+
+ // across d_head
+#pragma unroll
+ for (int j = 0; j < splitH; j++) {
+ const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+ state[j] = (state[j] * dA) + (B * x_dt);
+
+ // row j of stateC holds the per-state products whose sum over d_state is y for element j
+ stateC[j * d_state + threadIdx.x] = state[j] * C;
+ }
+
+ __syncthreads();
+
+ // parallel accumulation for stateC
+ // TODO: simplify
+ {
+ static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+ static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+ // reduce until w matches the warp size
+ // TODO: does this work even when the physical warp size is 64?
+#pragma unroll
+ for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+ // (assuming there are d_state threads)
+#pragma unroll
+ for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+ // TODO: check for bank conflicts
+ const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+ stateC[k] += stateC[k + (w >> 1)];
+
+ }
+ __syncthreads();
+ }
+
+ static_assert(splitH >= d_state / WARP_SIZE);
+
+ // the remaining WARP_SIZE partial sums of each row are summed with warp_reduce_sum
+#pragma unroll
+ for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+ float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+ y = warp_reduce_sum(y);
+
+ // store the above accumulations
+ if (threadIdx.x % WARP_SIZE == 0) {
+ const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+ y_block[i * stride_y + k] = y;
+ }
+ }
+ }
+ // NOTE(review): there is no barrier between the stateC reads above and the stateC writes at the
+ // top of the next token iteration; a warp that runs ahead could overwrite rows that another warp
+ // still has to read — confirm whether a __syncthreads() is needed here
+ }
+
+ // write back the state
+#pragma unroll
+ for (int j = 0; j < splitH; j++) {
+ s_block[j * d_state + threadIdx.x] = state[j];
+ }
+}
+
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, // host-side launcher: chooses between the ssm_scan_f32_group and ssm_scan_f32 kernels
- const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
- const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
- const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
- const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
- float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
+ const float * src4, const float * src5, const int32_t * src6, float * dst, // src6 is an int32 buffer forwarded unchanged to the kernel
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+ const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+ const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+ const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
- // todo: consider D cannot be divided,does this situation exist?
- GGML_ASSERT(D % threads == 0);
- const dim3 blocks(B, (D + threads - 1) / threads, 1);
- const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
- if (N == 16) {
- ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
- src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
- src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+ // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+ if (src3_nb1 == sizeof(float)) { // the dim-1 byte stride of src3 is exactly one float
+ // Mamba-2
+ if (d_state == 128) {
+ GGML_ASSERT(d_state % threads == 0);
+ // NOTE: can be any power of two between 4 and 64
+ const int splitH = 16; // presumably must match the first template argument of the launch below -- TODO confirm against the kernel's template declaration
+ GGML_ASSERT(head_dim % splitH == 0);
+ const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); // grid: x = ceil(n_head*head_dim / splitH), y = n_seq
+ ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>( // third launch argument 0: no dynamic shared memory requested
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+ } else {
+ GGML_ABORT("doesn't support d_state!=128.");
+ }
} else {
- GGML_ABORT("doesn't support N!=16.");
+ // Mamba-1
+ GGML_ASSERT(n_head % threads == 0);
+ GGML_ASSERT(head_dim == 1);
+ GGML_ASSERT(n_group == 1);
+ const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); // grid: x = n_seq, y = ceil(n_head / threads)
+ const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); // bytes for 2 * threads * (d_state + 1) floats, passed as the dynamic shared memory size
+ if (d_state == 16) {
+ ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ } else {
+ GGML_ABORT("doesn't support d_state!=16.");
+ }
}
}
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
-
- // const int64_t d_state = src0->ne[0];
- // const int64_t d_inner = src0->ne[1];
- // const int64_t l = src1->ne[1];
- // const int64_t b = src0->ne[2];
+ const struct ggml_tensor * src6 = dst->src[6]; // ids
const int64_t nc = src0->ne[0]; // d_state
- const int64_t nr = src0->ne[1]; // d_inner
- const int64_t n_t = src1->ne[1]; // number of tokens per sequence
- const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+ const int64_t nr = src0->ne[1]; // head_dim or 1
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1]; // n_group
+ const int64_t n_t = src1->ne[2]; // number of tokens per sequence
+ const int64_t n_s = src1->ne[3]; // number of sequences in the batch
+
+ const int64_t s_off = ggml_nelements(src1) * sizeof(float);
- GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
- // required for the dot product between s and C
- GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
- // required for per-sequence offsets for states
- GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
- // required to get correct offset for state destination (i.e. src1->nb[3])
- GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
const float * src3_d = (const float *) src3->data;
const float * src4_d = (const float *) src4->data;
const float * src5_d = (const float *) src5->data;
+ const int32_t * src6_d = (const int32_t *) src6->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src6->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
- src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
- src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+ ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+ src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+ src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+ s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
typedef struct {
int64_t d_state;
int64_t d_inner;
+ int64_t n_head; // filled from ne02 by the host encoder
+ int64_t n_group; // filled from src4->ne[1] by the host encoder
int64_t n_seq_tokens;
int64_t n_seqs;
- uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
- uint64_t nb10;
+ uint64_t nb03; // presumably the dim-3 byte stride of src0 (nbXY naming) -- TODO confirm at the host definition
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
- uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
- uint64_t nb30;
uint64_t nb31;
- uint64_t nb40;
uint64_t nb41;
uint64_t nb42;
- uint64_t nb50;
+ uint64_t nb43; // src4->nb[3]
uint64_t nb51;
uint64_t nb52;
+ uint64_t nb53; // src5->nb[3]
} ggml_metal_kargs_ssm_scan;
typedef struct {
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+ GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
+ struct ggml_tensor * src6 = node->src[6];
GGML_ASSERT(src3);
GGML_ASSERT(src4);
GGML_ASSERT(src5);
+ GGML_ASSERT(src6);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
+ size_t offs_src6 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+ id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
- const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+ const int64_t ne30 = src3->ne[0];
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
- const uint64_t nb30 = src3->nb[0];
+ const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
const uint64_t nb31 = src3->nb[1];
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
- const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+ const int64_t ne41 = src4->ne[1];
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+ const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
- const uint64_t nb40 = src4->nb[0];
+ const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
const uint64_t nb41 = src4->nb[1];
const uint64_t nb42 = src4->nb[2];
+ const uint64_t nb43 = src4->nb[3];
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+ const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
- const uint64_t nb50 = src5->nb[0];
+ const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
const uint64_t nb51 = src5->nb[1];
const uint64_t nb52 = src5->nb[2];
+ const uint64_t nb53 = src5->nb[3];
+
+ const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
+
+ const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
const int64_t d_state = ne00;
const int64_t d_inner = ne01;
- const int64_t n_seq_tokens = ne11;
- const int64_t n_seqs = ne02;
+ const int64_t n_head = ne02;
+ const int64_t n_group = ne41;
+ const int64_t n_seq_tokens = ne12;
+ const int64_t n_seqs = ne13;
+
+ id<MTLComputePipelineState> pipeline = nil;
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+ if (ne30 == 1) {
+ // Mamba-2
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+ }
ggml_metal_kargs_ssm_scan args = {
- /*.d_state =*/ d_state,
- /*.d_inner =*/ d_inner,
+ /*.d_state =*/ d_state,
+ /*.d_inner =*/ d_inner,
+ /*.n_head =*/ n_head,
+ /*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
- /*.n_seqs =*/ n_seqs,
- /*.nb00 =*/ nb00,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.nb10 =*/ nb10,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.nb13 =*/ nb13,
- /*.nb20 =*/ nb20,
- /*.nb21 =*/ nb21,
- /*.nb22 =*/ nb22,
- /*.nb30 =*/ nb30,
- /*.nb31 =*/ nb31,
- /*.nb40 =*/ nb40,
- /*.nb41 =*/ nb41,
- /*.nb42 =*/ nb42,
- /*.nb50 =*/ nb50,
- /*.nb51 =*/ nb51,
- /*.nb52 =*/ nb52,
+ /*.n_seqs =*/ n_seqs,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb31 =*/ nb31,
+ /*.nb41 =*/ nb41,
+ /*.nb42 =*/ nb42,
+ /*.nb43 =*/ nb43,
+ /*.nb51 =*/ nb51,
+ /*.nb52 =*/ nb52,
+ /*.nb53 =*/ nb53,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
- [encoder setBytes:&args length:sizeof(args) atIndex:7];
+ [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+ [encoder setBytes:&args length:sizeof(args) atIndex:8];
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ if (ne30 == 1) {
+ // Mamba-2
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } else {
+ GGML_ASSERT(d_inner == 1);
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
} break;
case GGML_OP_RWKV_WKV6:
{
x[0] = sumf;
}
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32(
device const void * src0,
device const void * src1,
device const void * src3,
device const void * src4,
device const void * src5,
+ device const void * src6, // read below as an int32 array (ids)
device float * dst,
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t ir = tgpig.x;
- const int64_t i3 = tgpig.y;
+ const int64_t i1 = 0; // index along dim 0 of x/y is fixed to 0 in this variant
+ const int64_t ir = tgpig.x; // current head
+ const int64_t i3 = tgpig.y; // current seq
+
+ const uint64_t nb00 = sizeof(float); // dim-0 strides are hard-coded to one float instead of being read from args
+ const uint64_t nb10 = sizeof(float);
+ const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
- // const int64_t nr = args.d_inner;
+ const int64_t nr = args.d_inner;
+ const int64_t nh = args.n_head;
+ const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
- // const int64_t n_s = args.n_seqs;
+
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); // byte offset in dst at which the states are written (nr*nh*n_t*n_seqs floats precede them)
+
+ device const int32_t * ids = (device const int32_t *) src6;
+
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); // initial state is read from src0 at the dim-3 index ids[i3]
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); // output state is addressed in dst with src0's strides, after the y region
for (int64_t i2 = 0; i2 < n_t; ++i2) {
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
- device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
- device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
- device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
- device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
- device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
- device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
- device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
-
- if (i2 > 0) {
- s0 = s;
- }
-
- // i1 == 0
- float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
- float x_dt = x[0] * dt_soft_plus;
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
+ device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}; (ir & (ng - 1)) equals ir mod ng only when ng is a power of two
+ device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}; same group index as B
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; // softplus(dt); inputs above 20 are passed through unchanged
+ const float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
- int64_t i = i0;
- float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ const int64_t i = i0 + i1*nc;
+ const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); // decay the previous state by exp(dt*A[i0]) and add the B-weighted input
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf;
+
+ // recurse: the next token reads the state that was just written to dst
+ s0 = s;
+ }
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// TODO: optimize (e.g. by parallelizing over d_state)
+//
+// Each threadgroup handles one (head-dim element, head, sequence) triple, taken from
+// tgpig.x / tgpig.y / tgpig.z, and walks the tokens of that sequence in order.
+// For every token it updates d_state state values and writes one y output.
+// dst holds the y outputs first; the updated states follow at byte offset s_off.
+// The initial state of sequence i3 is read from src0 at the dim-3 index ids[i3] (ids = src6).
+kernel void kernel_ssm_scan_f32_group(
+ device const void * src0,
+ device const void * src1,
+ device const void * src2,
+ device const void * src3,
+ device const void * src4,
+ device const void * src5,
+ device const void * src6,
+ device float * dst,
+ constant ggml_metal_kargs_ssm_scan & args,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i1 = tgpig.x;
+ const int64_t ir = tgpig.y; // current head
+ const int64_t i3 = tgpig.z; // current seq
+
+ // dim-0 strides are hard-coded to one float instead of being read from args
+ const uint64_t nb00 = sizeof(float);
+ const uint64_t nb10 = sizeof(float);
+ const uint64_t nb20 = sizeof(float);
+
+ const int64_t nc = args.d_state;
+ const int64_t nr = args.d_inner;
+ const int64_t nh = args.n_head;
+ const int64_t ng = args.n_group;
+ const int64_t n_t = args.n_seq_tokens;
+
+ // Heads map to B/C groups in contiguous blocks of nh/ng heads (repeat_interleave layout),
+ // so the group of head ir is ir / (nh / ng), not ir mod ng.
+ // Assumes nh is a non-zero multiple of ng.
+ const int64_t ig = ir / (nh / ng);
+
+ // byte size of the y region that precedes the states in dst
+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+ device const int32_t * ids = (device const int32_t *) src6;
+
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+ device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+
+ for (int64_t i2 = 0; i2 < n_t; ++i2) {
+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+ device const float * B = (device const float *) ((device const char *) src4 + ig*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+ device const float * C = (device const float *) ((device const char *) src5 + ig*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+ // softplus(dt); inputs above 20 are passed through unchanged
+ const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+ const float x_dt = x[0] * dt_soft_plus;
+ // a single decay factor per head, shared by all d_state elements
+ const float dA = exp(dt_soft_plus * A[0]);
+ float sumf = 0.0f;
+
+ for (int64_t i0 = 0; i0 < nc; ++i0) {
+ const int64_t i = i0 + i1*nc;
+ const float state = (s0[i] * dA) + (B[i0] * x_dt);
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+
+ y[0] = sumf;
+
+ // recurse: the next token reads the state that was just written to dst
+ s0 = s;
}
}
const int64_t n_s = sx->ne[2];
// TODO: maybe support other strides than 1?
- // FIXME: this is always true?
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT(sx->ne[1] == d_inner);
GGML_ASSERT(n_t >= 0);
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
- struct ggml_tensor * C) {
+ struct ggml_tensor * C,
+ struct ggml_tensor * ids) {
GGML_ASSERT(ggml_is_contiguous(s));
- GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
- GGML_ASSERT(ggml_is_matrix(A));
- GGML_ASSERT(ggml_is_3d(B));
- GGML_ASSERT(ggml_is_3d(s));
+ GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
- GGML_ASSERT(ggml_are_same_shape(x, dt));
+ GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+ GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+ GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C));
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
{
const int64_t d_state = s->ne[0];
- const int64_t d_inner = s->ne[1];
- const int64_t n_seq_tokens = x->ne[1];
- const int64_t n_seqs = x->ne[2];
-
- GGML_ASSERT(s->ne[2] == n_seqs);
- GGML_ASSERT(x->ne[0] == d_inner);
- GGML_ASSERT(A->ne[0] == d_state);
- GGML_ASSERT(A->ne[1] == d_inner);
+ const int64_t head_dim = x->ne[0];
+ const int64_t n_head = x->ne[1];
+ const int64_t n_seq_tokens = x->ne[2];
+ const int64_t n_seqs = x->ne[3];
+
+ GGML_ASSERT(dt->ne[0] == n_head);
+ GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+ GGML_ASSERT(dt->ne[2] == n_seqs);
+ GGML_ASSERT(ggml_is_3d(dt));
+ GGML_ASSERT(s->ne[1] == head_dim);
+ GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(B->ne[0] == d_state);
- GGML_ASSERT(B->ne[1] == n_seq_tokens);
- GGML_ASSERT(B->ne[2] == n_seqs);
+ GGML_ASSERT(B->ne[2] == n_seq_tokens);
+ GGML_ASSERT(B->ne[3] == n_seqs);
+ GGML_ASSERT(ids->ne[0] == n_seqs);
+ GGML_ASSERT(ggml_is_vector(ids));
+ GGML_ASSERT(A->ne[1] == n_head);
+ GGML_ASSERT(ggml_is_matrix(A));
+
+ if (A->ne[0] != 1) {
+ // Mamba-1 has more granular decay factors
+ GGML_ASSERT(A->ne[0] == d_state);
+ }
}
// concatenated y + ssm_states
- struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
result->op = GGML_OP_SSM_SCAN;
result->src[0] = s;
result->src[3] = A;
result->src[4] = B;
result->src[5] = C;
+ result->src[6] = ids;
return result;
}
INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+ GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
class WKV:
RWKV7 = auto()
ARWKV7 = auto()
MAMBA = auto()
+ MAMBA2 = auto()
XVERSE = auto()
COMMAND_R = auto()
COHERE2 = auto()
SSM_DT = auto()
SSM_A = auto()
SSM_D = auto()
+ SSM_NORM = auto()
SSM_OUT = auto()
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba",
+ MODEL_ARCH.MAMBA2: "mamba2",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.COHERE2: "cohere2",
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
+ MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
],
+ MODEL_ARCH.MAMBA2: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.SSM_IN,
+ MODEL_TENSOR.SSM_CONV1D,
+ MODEL_TENSOR.SSM_DT,
+ MODEL_TENSOR.SSM_A,
+ MODEL_TENSOR.SSM_D,
+ MODEL_TENSOR.SSM_NORM,
+ MODEL_TENSOR.SSM_OUT,
+ ],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
# tokenization
def add_ssm_time_step_rank(self, value: int) -> None:  # passes `value` to add_uint32 under the Keys.SSM.TIME_STEP_RANK key formatted with self.arch
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
+ def add_ssm_group_count(self, value: int) -> None:  # passes `value` to add_uint32 under the Keys.SSM.GROUP_COUNT key ("{arch}.ssm.group_count") formatted with self.arch
+ self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
+
def add_ssm_dt_b_c_rms(self, value: bool) -> None:  # passes `value` to add_bool under the Keys.SSM.DT_B_C_RMS key formatted with self.arch
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
"encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
- "encoder.layer.{bid}.layer_norm_2" # jina-v2-code
+ "encoder.layer.{bid}.layer_norm_2", # jina-v2-code
),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
"backbone.layers.{bid}.mixer.D",
),
+ MODEL_TENSOR.SSM_NORM: (
+ "backbone.layers.{bid}.mixer.norm", # mamba2
+ ),
+
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
+ { LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
+ { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
+ {
+ LLM_ARCH_MAMBA2,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+ { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+ },
+ },
{
LLM_ARCH_XVERSE,
{
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+ {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
+ case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
+ LLM_ARCH_MAMBA2,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
+ LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_WKV_HEAD_SIZE,
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D,
+ LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
- bool avoid_copies) const {
+ const llm_graph_get_rows_fn & get_state_rows) const {
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
- ggml_tensor * output_states;
-
- if (!avoid_copies) {
- // copy states
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
- // {state_size, kv_size} -> {state_size, n_seqs}
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
- ggml_build_forward_expand(gf, output_states);
- } else {
- // FIXME: make the gathering operation happen before the copy below
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
- output_states = states;
- }
+ // copy states
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+ // {state_size, kv_size} -> {state_size, n_seqs}
+ ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+ ggml_build_forward_expand(gf, output_states);
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
- bool avoid_copies) const {
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+ const llm_graph_get_rows_fn & get_state_rows) const {
+ const auto * kv_state = static_cast<const llama_memory_recurrent_context *>(mctx);
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}
ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
- bool avoid_copies) const {
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
+ const llm_graph_get_rows_fn & get_state_rows) const {
+ const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
const llm_graph_cb & cb;
};
+// used in build_rs to properly order writes and avoid unnecessary copies
+using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
+
struct llm_graph_context {
const llm_arch arch;
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
llm_graph_input_rs * build_rs_inp() const;
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
ggml_tensor * build_rs(
llm_graph_input_mem_hybrid * inp,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
- bool avoid_copies = false) const;
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
ggml_tensor * build_rwkv_token_shift_load(
llm_graph_input_rs * inp,
// TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
- return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ // Corresponds to Mamba's conv_states size
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
}
uint32_t llama_hparams::n_embd_s() const {
uint32_t ssm_d_inner = 0;
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
+ uint32_t ssm_n_group = 0;
// for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
} break;
case GGML_OP_SSM_CONV:
{
- // FIXME
- ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
+ const int64_t n_seq_tokens = 512;
+ const int64_t n_seqs = 3;
+ ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs);
op_tensor = ggml_ssm_conv(ctx, conv_x, w);
} break;
case GGML_OP_SSM_SCAN:
{
- // FIXME
- const int64_t d_state = w->ne[0];
- const int64_t d_inner = w->ne[1];
+ // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2
+ const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; // Mamba-2 creates ssm_a as {1, n_head}, so d_state has to come from hparams
+ const int64_t n_head = w->ne[1];
+ const int64_t head_dim = hparams.ssm_d_inner / n_head; // evaluates to 1 when ssm_a is the Mamba-1 {d_state, d_inner} tensor
+ const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; // ssm_n_group is initialized to 0; fall back to a single group
const int64_t n_seq_tokens = 512;
- const int64_t n_seqs = 1;
- ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
- ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
- ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
- ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
- ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
- op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
+ const int64_t n_seqs = 3;
+ ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);
+ ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs);
+ ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs);
+ ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+ ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs);
+ ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); // one I32 entry per sequence
+ op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids);
} break;
case GGML_OP_RWKV_WKV6:
{
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_MAMBA2:
+ {
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); // for Mamba-2 this value is later read back as the number of heads (n_head)
+ ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
+
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) { // the model type label is picked from the (n_layer, n_embd) pair
+ case 24:
+ switch (hparams.n_embd) {
+ case 768: type = LLM_TYPE_SMALL; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ } break;
+ case 48:
+ switch (hparams.n_embd) {
+ case 1024: type = LLM_TYPE_MEDIUM; break;
+ case 1536: type = LLM_TYPE_LARGE; break;
+ case 2048: type = LLM_TYPE_XL; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ } break;
+ case 64:
+ switch (hparams.n_embd) {
+ case 2560: type = LLM_TYPE_3B; break;
+ case 4096: type = LLM_TYPE_7B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ } break;
+ default: type = LLM_TYPE_UNKNOWN; // any other layer count is left unclassified
+ }
+ } break;
case LLM_ARCH_XVERSE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
+ // out_proj
+ layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+ }
+ } break;
+ case LLM_ARCH_MAMBA2:
+ {
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t n_head = hparams.ssm_dt_rank; // Mamba-2 keeps its head count in ssm_dt_rank
+ const int64_t n_group = hparams.ssm_n_group;
+ const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; // z (d_inner) + xBC (d_inner + 2*n_group*d_state) + dt (n_head)
+
+ // only an expansion factor of 2 is supported for now
+ GGML_ASSERT(2 * n_embd == d_inner);
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ {
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ // norm
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
+
+ layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); // the convolution covers x, B and C together
+ layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0);
+
+ layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
+
+ // no "weight" suffix for these
+ layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
+ layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
+
+ layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); // weight of the grouped RMS norm applied before out_proj
+
// out_proj
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
}
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+ }
+
+ if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
+ LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
if (!classifier_labels.empty()) {
};
struct llm_build_mamba : public llm_graph_context {
- const llama_model & model;
-
- llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+ llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { // model is no longer kept as a member; it is handed to the layer builders instead
ggml_tensor * cur;
ggml_tensor * inpL;
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
- cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
+ if (model.arch == LLM_ARCH_MAMBA2) { // this builder serves both architectures; only the per-layer SSM block differs
+ cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il);
+ } else {
+ cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il);
+ }
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
ggml_build_forward_expand(gf, cur);
}
- // TODO: split
ggml_tensor * build_mamba_layer(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
+ const llama_model & model,
const llama_ubatch & ubatch,
int il) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t dt_rank = hparams.ssm_dt_rank;
+ const int64_t n_head = d_inner; // Mamba-1 is expressed as d_inner heads of size 1 for the shared scan operator
+ const int64_t head_dim = 1;
const int64_t n_seqs = ubatch.n_seqs;
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
- // (ab)using the KV cache to store the states
- ggml_tensor * conv = build_rs(
- inp, gf, conv_states_all,
- hparams.n_embd_r(), n_seqs);
+ ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
- ggml_tensor * ssm = build_rs(
- inp, gf, ssm_states_all,
- hparams.n_embd_s(), n_seqs);
- ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
// split
ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
- ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
- ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
+ ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
+ ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
if (ssm_dt_b_c_rms) {
dt = build_lora_mm(model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
- // Custom operator to optimize the parallel associative scan
- // as described in the Annex D of the Mamba paper.
- // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
- ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C);
+ cur = x;
+ x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
+
+ ggml_tensor * A = model.layers[il].ssm_a;
+
+ // use the states and the indices provided by build_rs
+ // (this is necessary in order to properly use the states before they are overwritten,
+ // while avoiding unnecessary copies of the states)
+ auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+
+ // Custom operator to optimize the parallel associative scan
+ // as described in the Annex D of the Mamba paper.
+ // => the result holds y {d_inner, n_seq_tokens, n_seqs} followed by the new states {d_state, d_inner, n_seqs}
+ return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+ };
+
+ ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+
+ // store last states (read from y_ssm at the byte offset just past the x-sized y part)
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0,
+ ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
+ ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+ ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);
+
+ // TODO: skip computing output earlier for unused tokens
+
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d));
+ y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+
+ // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+ cur = build_lora_mm(model.layers[il].ssm_out, y);
+ }
+
+ // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+ // cb(cur, "mamba_out", il);
+
+ return cur;
+ }
+
+ ggml_tensor * build_mamba2_layer(
+ llm_graph_input_rs * inp,
+ ggml_cgraph * gf,
+ ggml_tensor * cur,
+ const llama_model & model,
+ const llama_ubatch & ubatch,
+ int il) const {
+ const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+ const auto kv_head = mctx_cur->get_head();
+
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t n_head = hparams.ssm_dt_rank; // Mamba-2 keeps its head count in ssm_dt_rank
+ const int64_t head_dim = d_inner / n_head;
+ const int64_t n_group = hparams.ssm_n_group;
+ const int64_t n_seqs = ubatch.n_seqs;
+
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+ GGML_ASSERT(n_seqs != 0);
+ GGML_ASSERT(ubatch.equal_seqs);
+ GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+ ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+
+ ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
+ conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
+
+ // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+ cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+ // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
+
+ // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
+ ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
+
+ // split the above in three: z (gate), xBC (convolution input) and dt
+ ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
+ ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
+ ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
+
+ // conv
+ {
+ // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
+ ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
+
+ // copy last (d_conv - 1) columns back into the state cache
+ ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0, last_conv,
+ ggml_view_1d(ctx0, conv_states_all,
+ (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
+ kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
+
+ // 1D convolution
+ // The equivalent is to make a self-overlapping view of conv_x
+ // over d_conv columns at each stride in the 3rd dimension,
+ // then element-wise multiply that with the conv1d weight,
+ // then sum the elements of each row,
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
+ // then permute away the ne[0] dimension,
+ // and then you're left with the resulting x tensor.
+ // For simultaneous sequences, all sequences need to have the same length.
+ xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+
+ // bias
+ xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
+
+ xBC = ggml_silu(ctx0, xBC);
+ }
+
+ // ssm
+ {
+ // These correspond to V K Q in SSM/attention duality
+ ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
+ ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
+ ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
+
+ // {n_head, n_seq_tokens, n_seqs}
+ dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
+
+ ggml_tensor * A = model.layers[il].ssm_a;
+
+ // use the states and the indices provided by build_rs
+ // (this is necessary in order to properly use the states before they are overwritten,
+ // while avoiding unnecessary copies of the states)
+ auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+
+ // TODO: use semistructured matrices to implement state-space duality
+ // => the result holds y {d_inner, n_seq_tokens, n_seqs} followed by the new states {d_state, d_inner, n_seqs}
+ return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+ };
+
+ ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
// store last states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
- ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+ ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), // x is a strided view of xBC here, so the size of the y part is taken from its element count rather than from x->nb[3]
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
- ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+ ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); // contiguous strides rebuilt from x->nb[1], which spans head_dim elements
// TODO: skip computing output earlier for unused tokens
- // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+ // grouped RMS norm
+ y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+ y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+ y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
+
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(model.layers[il].ssm_out, y);
}
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
- //cb(cur, "mamba_out", il);
+ // cb(cur, "mamba_out", il);
return cur;
}
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
} break;
case LLM_ARCH_MAMBA:
+ case LLM_ARCH_MAMBA2:
{
llm = std::make_unique<llm_build_mamba>(*this, params, gf);
} break;
case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM:
case LLM_ARCH_MAMBA:
+ case LLM_ARCH_MAMBA2:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
struct ggml_tensor * ffn_sub_norm = nullptr;
struct ggml_tensor * attn_norm_cross = nullptr;
struct ggml_tensor * attn_norm_enc = nullptr;
+ struct ggml_tensor * ssm_norm = nullptr;
// attention
struct ggml_tensor * wq = nullptr;
const ggml_type type;
const int64_t d_state;
- const int64_t d_inner;
+ const int64_t head_dim;
+ const int64_t n_head;
+ const int64_t n_group;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
- return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+ return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); // seven fields are reported now, hence VARS_TO_STR7
}
test_ssm_scan(ggml_type type = GGML_TYPE_F32,
- int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
- : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+ int64_t d_state = 32,
+ int64_t head_dim = 1, // 1 for Mamba-1, greater than 1 for Mamba-2
+ int64_t n_head = 32,
+ int64_t n_group = 1,
+ int64_t n_seq_tokens = 32,
+ int64_t n_seqs = 32)
+ : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
- ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1 , 1 }.data());
- ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
- ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+ ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
+ ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+ ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
+ ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head); // {1, n_head} when head_dim > 1 (Mamba-2), {d_state, n_head} otherwise (Mamba-1)
+ ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+ ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+ ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); // one I32 entry per sequence; filled by initialize_tensors
+ ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
return out;
}
+
+ // similar to test_mul_mat_id: I32 tensors get a shuffled index permutation per row, all other tensors get uniform random data
+ void initialize_tensors(ggml_context * ctx) override {
+ std::random_device rd;
+ std::default_random_engine rng(rd()); // seeded from random_device, so the permutation differs between runs
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; } // do not write into view tensors
+ // ids: every row becomes a random permutation of 0 .. ne[0]-1
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<int32_t> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); // upload one row at byte offset r * nb[1]
+ }
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
};
// GGML_OP_RWKV_WKV6
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
- test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+ test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1: d_state=16, head_dim=1, n_head=1024, n_group=1, n_seq_tokens=32, n_seqs=4
+ test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2: d_state=128, head_dim=64, n_head=16, n_group=2, n_seq_tokens=32, n_seqs=4
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));