"FLASH_ATTN",
"FLASH_FF",
"FLASH_ATTN_BACK",
+ "SSM_CONV",
+ "SSM_SCAN",
"WIN_PART",
"WIN_UNPART",
"GET_REL_POS",
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"flash_attn(x)",
"flash_ff(x)",
"flash_attn_back(x)",
+ "ssm_conv(x)",
+ "ssm_scan(x)",
"win_part(x)",
"win_unpart(x)",
"get_rel_pos(x)",
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
return result;
}
+// ggml_ssm_conv
+
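+// an SSM convolution step over a batch of tokens, fused with the conv state update
+//   s:  {d_conv - 1, d_inner, n_kv}  previous conv states
+//   x:  {d_inner, n_tokens}          new input columns
+//   c:  {d_conv, d_inner}            conv1d weights
+//   sq: {n_kv, n_tokens}  (I32)      seq_ids selecting which states each token reads and writes
+// the result concatenates the convolved x {d_inner, n_tokens} with the updated conv states {d_conv, d_inner, n_kv}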
+struct ggml_tensor * ggml_ssm_conv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * c,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_3d(s));
+ GGML_ASSERT(ggml_is_matrix(x));
+ GGML_ASSERT(ggml_is_matrix(c));
+ GGML_ASSERT(ggml_is_matrix(sq));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+
+ const int64_t d_conv = c->ne[0];
+ const int64_t d_inner = c->ne[1];
+ const int64_t n_tokens = x->ne[1];
+ const int64_t n_kv = s->ne[2];
+
+ GGML_ASSERT( s->ne[0] == d_conv - 1);
+ GGML_ASSERT( s->ne[1] == d_inner);
+ GGML_ASSERT( x->ne[0] == d_inner);
+ GGML_ASSERT(sq->ne[0] == n_kv);
+ GGML_ASSERT(sq->ne[1] == n_tokens);
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || c->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+
+ result->op = GGML_OP_SSM_CONV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = c;
+ result->src[3] = sq;
+
+ return result;
+}
+
+// ggml_ssm_scan
+
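+// a selective state-space scan over a batch of tokens, fused with the ssm state update
+//   s:  {d_state, d_inner, n_kv}  previous ssm states
+//   x:  {d_inner, n_tokens}
+//   dt: {d_inner, n_tokens}
+//   A:  {d_state, d_inner}
+//   B:  {d_state, n_tokens}
+//   C:  {d_state, n_tokens}
+//   sq: {n_kv, n_tokens}  (I32)   seq_ids selecting which states each token reads and writes
+// the result concatenates y {d_inner, n_tokens} with the updated ssm states {d_state, d_inner, n_kv}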
+struct ggml_tensor * ggml_ssm_scan(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * dt,
+ struct ggml_tensor * A,
+ struct ggml_tensor * B,
+ struct ggml_tensor * C,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_contiguous(s));
+ GGML_ASSERT(ggml_is_contiguous(x));
+ GGML_ASSERT(ggml_is_contiguous(dt));
+ GGML_ASSERT(ggml_is_contiguous(A));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+ GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+ GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
+ GGML_ASSERT(ggml_are_same_shape(x, dt));
+
+ {
+ const int64_t d_state = s->ne[0];
+ const int64_t d_inner = s->ne[1];
+ const int64_t n_tokens = x->ne[1];
+
+ GGML_ASSERT(x->ne[0] == d_inner);
+ GGML_ASSERT(A->ne[0] == d_state);
+ GGML_ASSERT(A->ne[1] == d_inner);
+ GGML_ASSERT(B->ne[0] == d_state);
+ GGML_ASSERT(B->ne[1] == n_tokens);
+ GGML_ASSERT(C->ne[0] == d_state);
+ GGML_ASSERT(C->ne[1] == n_tokens);
+ }
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+
+ result->op = GGML_OP_SSM_SCAN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = dt;
+ result->src[3] = A;
+ result->src[4] = B;
+ result->src[5] = C;
+ result->src[6] = sq;
+
+ return result;
+}
+
// ggml_win_part
struct ggml_tensor * ggml_win_part(
}
}
+// ggml_compute_forward_ssm_conv
+
+static void ggml_compute_forward_ssm_conv_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+ const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src2->ne[0]; // d_conv
+ const int nr = src0->ne[1]; // d_inner
+ const int n_t = src1->ne[1]; // n_tokens
+ const int n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // src2->nb[2] is used below as the per-sequence stride of the destination conv states
+ GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+ // with multiple sequences, it's hard to know when a state is read for the first time,
+ // so copy them all over to the destination, just to be sure.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
+ // can't use memcpy because of d_conv vs d_conv - 1
+ for (int i1 = 0; i1 < ir; ++i1) {
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ // copy s0 to last (d_conv - 1) columns of s
+ s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
+ }
+ }
+ }
+ }
+
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
+ float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
+ float * s0; // {d_conv - 1, d_inner, n_kv}
+ float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
+ int ne0s0;
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
+ ne0s0 = src0->ne[0];
+ } else {
+ // the source is the last (d_conv - 1) columns of the destination
+ s0 = s + 1;
+ ne0s0 = nc;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // shift state left
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
+ }
+ // insert x on the last column
+ s[(nc - 1) + i1*nc] = x0[i1];
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or out-of-range seq_ids
+ break;
+ }
+ }
+
+ // it seems a little faster when this is separate from the state shift
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // rowwise dot product
+ float sumf = 0.0f;
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ sumf += s[i] * c[i];
+ }
+ x[i1] = sumf;
+ }
+ }
+}
+
+static void ggml_compute_forward_ssm_conv(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_conv_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const struct ggml_tensor * src0 = dst->src[0]; // s
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // dt
+ const struct ggml_tensor * src3 = dst->src[3]; // A
+ const struct ggml_tensor * src4 = dst->src[4]; // B
+ const struct ggml_tensor * src5 = dst->src[5]; // C
+ const struct ggml_tensor * src6 = dst->src[6]; // sq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // d_inner
+ const int64_t n_t = src1->ne[1]; // number of tokens in the batch
+ const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(float));
+ GGML_ASSERT(src4->nb[0] == sizeof(float));
+ GGML_ASSERT(src5->nb[0] == sizeof(float));
+ // required for the dot product between s and C, and when copying the states
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // required for per-sequence offsets for states
+ GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+ // required to get correct offset for state destination (i.e. src1->nb[2])
+ GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+ // when there are multiple sequences, it's hard to know whether the source states
+ // have already been copied, so copy them all to the destination, just to be sure.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
+ memcpy(s, s0, nc*ir*sizeof(float));
+ }
+ }
+
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
+ float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+ float * s0;
+ float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
+ float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+ float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
+ float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
+ } else {
+ // otherwise the source is the same as the destination
+ s0 = s;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
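+ // softplus(dt) = log(1 + exp(dt)); above the threshold the result equals dt to within
+ // float precision, and passing dt through directly also avoids overflowing expf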
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+ float x_dt = x[i1] * dt_soft_plus;
+ float sumf = 0.0f;
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ // state = prev_state * dA + dB * x
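+ // with dA = exp(dt_soft_plus * A) and dB = dt_soft_plus * B,
+ // following the discretization used in the reference selective_state_update (see ref above)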
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+ y[i1] = sumf;
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or out-of-range seq_ids
+ break;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_ssm_scan(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_scan_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_win_part
static void ggml_compute_forward_win_part_f32(
bool masked = t != 0;
ggml_compute_forward_flash_attn_back(params, masked, tensor);
} break;
+ case GGML_OP_SSM_CONV:
+ {
+ ggml_compute_forward_ssm_conv(params, tensor);
+ } break;
+ case GGML_OP_SSM_SCAN:
+ {
+ ggml_compute_forward_ssm_scan(params, tensor);
+ } break;
case GGML_OP_WIN_PART:
{
ggml_compute_forward_win_part(params, tensor);
{
GGML_ASSERT(false); // not supported
} break;
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_UNARY:
{
n_tasks = n_threads;
} break;
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ n_tasks = n_threads;
+ } break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_GET_REL_POS:
LLM_ARCH_MINICPM,
LLM_ARCH_GEMMA,
LLM_ARCH_STARCODER2,
+ LLM_ARCH_MAMBA,
LLM_ARCH_UNKNOWN,
};
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
+ { LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
LLM_KV_ROPE_SCALING_FINETUNED,
+ LLM_KV_SSM_INNER_SIZE,
+ LLM_KV_SSM_CONV_KERNEL,
+ LLM_KV_SSM_STATE_SIZE,
+ LLM_KV_SSM_TIME_STEP_RANK,
+
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_LIST,
LLM_KV_TOKENIZER_TOKEN_TYPE,
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
+ { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" },
+ { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
+ { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
+ { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
+
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
{ LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,
+ LLM_TENSOR_SSM_IN,
+ LLM_TENSOR_SSM_CONV1D,
+ LLM_TENSOR_SSM_X,
+ LLM_TENSOR_SSM_DT,
+ LLM_TENSOR_SSM_A,
+ LLM_TENSOR_SSM_D,
+ LLM_TENSOR_SSM_OUT,
};
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_MAMBA,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+ { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
{
float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
+ // for State Space Models
+ uint32_t ssm_d_conv = 0;
+ uint32_t ssm_d_inner = 0;
+ uint32_t ssm_d_state = 0;
+ uint32_t ssm_dt_rank = 0;
+
float f_clamp_kqv = 0.0f;
float f_max_alibi_bias = 0.0f;
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
+ if (this->ssm_d_conv != other.ssm_d_conv) return true;
+ if (this->ssm_d_inner != other.ssm_d_inner) return true;
+ if (this->ssm_d_state != other.ssm_d_state) return true;
+ if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
+
const float EPSILON = 1e-9f;
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
}
uint32_t n_gqa() const {
+ if (n_head_kv == 0) {
+ return 0;
+ }
return n_head/n_head_kv;
}
uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
return n_embd_head_v * n_head_kv;
}
+
+ uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
+ // corresponds to Mamba's conv_states size
+ // TODO: maybe support convolution strides other than 1
+ // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
+ }
+
+ uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
+ // corresponds to Mamba's ssm_states size
+ return ssm_d_state * ssm_d_inner;
+ }
};
struct llama_cparams {
struct ggml_tensor * ffn_down_b; // b2
struct ggml_tensor * ffn_up_b; // b3
struct ggml_tensor * ffn_act;
+
+ // mamba proj
+ struct ggml_tensor * ssm_in;
+ struct ggml_tensor * ssm_x;
+ struct ggml_tensor * ssm_dt;
+ struct ggml_tensor * ssm_out;
+
+ // mamba
+ struct ggml_tensor * ssm_conv1d;
+ struct ggml_tensor * ssm_a;
+ struct ggml_tensor * ssm_d;
+
+ // mamba bias
+ struct ggml_tensor * ssm_conv1d_b;
+ struct ggml_tensor * ssm_dt_b;
};
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
+ int32_t src = 0; // used by recurrent state models to copy states
std::set<llama_seq_id> seq_id;
struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
+ bool do_copy = false;
+ // with recurrent state models, a cell can hold the state for more than one past token
+ bool recurrent = false;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
- struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
- struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
- struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
+ struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
+ struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
+ struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
+ struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+ struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+ struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
#ifdef GGML_USE_MPI
ggml_mpi_context * ctx_mpi = NULL;
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
- uint32_t n_ctx,
+ uint32_t kv_size,
bool offload) {
const struct llama_hparams & hparams = model.hparams;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
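+ // for recurrent models (like Mamba) the regular GQA sizes are zero and only the state sizes
+ // contribute here, and vice versa for attention models (see the asserts below)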
const int64_t n_layer = hparams.n_layer;
cache.has_shift = false;
+ // TODO: find a nicer way to add other recurrent model architectures
+ cache.recurrent = model.arch == LLM_ARCH_MAMBA;
+
+ // TODO: support mixed recurrent Transformer architectures
+ // NOTE: (!a || b) is a logical implication (a -> b)
+ GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
+ GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
+ GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
+ GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());
+
cache.head = 0;
- cache.size = n_ctx;
+ cache.size = kv_size;
cache.used = 0;
cache.type_k = type_k;
cache.type_v = type_v;
cache.cells.clear();
- cache.cells.resize(n_ctx);
+ cache.cells.resize(kv_size);
+
+ if (cache.recurrent) {
+ // init state copy sources
+ for (uint32_t i = 0; i < cache.size; ++i) {
+ cache.cells[i].src = i;
+ }
+ }
#ifdef GGML_USE_CLBLAST
offload = false;
for (int i = 0; i < (int) n_layer; i++) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
+ ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+ ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens;
+ if (cache.recurrent) {
+ // For recurrent state architectures (like Mamba),
+ // each KV cache cell can store the state for a whole sequence.
+
+ llama_seq_id min = cache.size - 1;
+ llama_seq_id max = 0;
+
+ for (uint32_t i = 0; i < n_tokens; ++i) {
+ for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) {
+ llama_seq_id seq_id = batch.seq_id[i][j];
+ // make sure it's a valid seq_id
+ if ((uint32_t) seq_id < cache.size) {
+ if (seq_id > max) {
+ max = seq_id;
+ }
+ if (seq_id < min) {
+ min = seq_id;
+ }
+ // Assuming the tokens are in-order
+ if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
+ // What should happen when the pos backtracks or skips a value?
+ // Clearing the state mid-batch would require special-casing which isn't done.
+ LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
+ __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
+ }
+ if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) {
+ cache.used += 1;
+ }
+ cache.cells[seq_id].pos = batch.pos[i];
+ // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
+ } else {
+ // seq_id is too big
+ // TODO: would it be possible to resize the KV cache instead?
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d; try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+ return false;
+ }
+ }
+ }
+
+ // allow getting the range of used cells, from head to head + n
+ cache.head = min;
+ cache.n = max - min + 1;
+
+ // sanity check
+ return max >= min;
+ }
+ // otherwise, one cell per token.
+
if (n_tokens > n_ctx) {
LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
return false;
cache.used = 0;
}
-static void llama_kv_cache_seq_rm(
+static bool llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ // models like Mamba can't have a state partially erased
+ if (cache.recurrent) {
+ if (seq_id >= (int64_t) cache.size) {
+ // could be fatal
+ return false;
+ }
+ if (0 <= seq_id) {
+ // partial intersection is invalid
+ if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
+ return false;
+ }
+ } else {
+ // if seq_id is negative, the range should include everything or nothing
+ if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+ return false;
+ }
+ }
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
if (seq_id < 0) {
// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
+
+ return true;
}
static void llama_kv_cache_seq_cp(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+ seq_id_src = cache.cells[seq_id_src].src;
+ GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+ // intent to "copy from"
+ // supports copy chains thanks to taking the source of the source
+ cache.cells[seq_id_dst].src = seq_id_src;
+
+ // preserve the "keep or clear" status of the copied sequence
+ if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+ cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+ } else {
+ cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
+ }
+
+ cache.do_copy = true;
+
+ cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
+ }
+ return;
+ }
+ // otherwise, this is the KV cache of a Transformer-like model
+
cache.head = 0;
for (uint32_t i = 0; i < cache.size; ++i) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ // for Mamba-like models, only the pos needs to be shifted
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+ llama_kv_cell & cell = cache.cells[seq_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos += delta;
+ }
+ }
+ return;
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+ if (cache.recurrent) {
+ // for Mamba-like models, only the pos needs to be changed
+ if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+ llama_kv_cell & cell = cache.cells[seq_id];
+ if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+ cell.pos /= d;
+ }
+ }
+ return;
+ }
+
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.has_shift = true;
// sanity check for n_rot (optional)
{
- hparams.n_rot = hparams.n_embd / hparams.n_head;
+ hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
// gpt-j n_rot = rotary_dim
}
- hparams.n_embd_head_k = hparams.n_embd / hparams.n_head;
+ hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
- hparams.n_embd_head_v = hparams.n_embd / hparams.n_head;
+ hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
// arch-specific KVs
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
+ ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
+ ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
+ ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 24:
+ switch (hparams.n_embd) {
+ case 768: model.type = e_model::MODEL_SMALL; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 48:
+ switch (hparams.n_embd) {
+ case 1024: model.type = e_model::MODEL_MEDIUM; break;
+ case 1536: model.type = e_model::MODEL_LARGE; break;
+ case 2048: model.type = e_model::MODEL_XL; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ case 64:
+ switch (hparams.n_embd) {
+ case 2560: model.type = e_model::MODEL_3B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ } break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
+ LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
+ LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
+ LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
+ LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
if (ml.n_elements >= 1e12) {
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff});
}
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+ // only an expansion factor of 2 is supported for now
+ GGML_ASSERT(2 * n_embd == d_inner);
+
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+ // output
+ {
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+ // if output is NULL, init from the input tok embed, duplicated to allow offloading
+ if (model.output == NULL) {
+ model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+ ml.n_created--; // artificial tensor
+ ml.size_data += ggml_nbytes(model.output);
+ }
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ ggml_context * ctx_layer = ctx_for_layer(i);
+ ggml_context * ctx_split = ctx_for_layer_split(i);
+
+ auto & layer = model.layers[i];
+
+ // norm
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+ layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
+
+ layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
+ layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
+
+ layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
+
+ layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
+ layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
+
+ // no "weight" suffix for these
+ layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+ layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
+
+ // out_proj
+ layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ GGML_ASSERT(kv.size == n_ctx);
+
// compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(kq, "kq_soft_max_ext", il);
}
+ GGML_ASSERT(kv.size == n_ctx);
+
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
norm_eps (hparams.f_norm_eps),
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
- n_kv (worst_case ? n_ctx : kv_self.n),
- kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
+ n_kv (worst_case ? kv_self.size : kv_self.n),
+ kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
struct ggml_cgraph * build_k_shift() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ GGML_ASSERT(kv_self.size == n_ctx);
+
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
return gf;
}
+ struct ggml_cgraph * build_s_copy() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ GGML_ASSERT(kv_self.recurrent);
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+
+ conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+ ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+ // TODO: name the intermediate tensors with cb()
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+ }
+
+ return gf;
+ }
+
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
return gf;
}
+
+ struct ggml_cgraph * build_mamba() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ const int64_t d_model = n_embd;
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ GGML_ASSERT(2 * d_model == d_inner);
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ // {n_embd, n_tokens}
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+ cb(inpL, "inp_embd", -1);
+
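+ // inp_s_mask is 1.0f for cells whose state should be kept and 0.0f for cells whose state
+ // should be cleared (sequences starting in this batch); inp_s_seq lists, for each token,
+ // the cell indices (relative to kv_head) whose states it reads and writes, padded with -1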
+ struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
+ struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+
+ for (int il = 0; il < n_layer; ++il) {
+ // (ab)using the KV cache to store the states
+ struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+ struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+
+ // clear states of sequences which are starting at the beginning of this batch
+ {
+ conv_states = ggml_mul(ctx0,
+ ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
+ state_mask);
+ ssm_states = ggml_mul(ctx0,
+ ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]),
+ state_mask);
+ }
+
+ conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv);
+ ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv);
+
+ // norm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
+ struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
+ // split the above in two
+ // => {d_inner, n_tokens}
+ struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
+ struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
+
+ // conv
+ {
+ // Custom operator which is needed only to ease simultaneous sequence processing.
+ // For a single sequence, the equivalent is to concatenate the columns of conv_states and x,
+ // then make a self-overlapping view of that over d_conv columns at each stride in the 3rd dimension,
+ // then element-wise multiply that with the conv1d weight,
+ // then sum the elements of each row,
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
+ // then permute away the ne[0] dimension,
+ // and then you're left with the resulting x tensor.
+ // The new conv_states is the last (d_conv - 1) columns
+ // of the last 3rd dimensional "layer" of the self-overlapping view.
+ // For simultaneous sequences, it's more complicated.
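+ // For example, with d_conv == 4 and a conv state [s0 s1 s2] for some row, a new value x0
+ // makes the window [s0 s1 s2 x0]; its dot product with that row of the conv1d weight
+ // gives the convolved value, and the updated conv state for the row becomes [s1 s2 x0].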
+ struct ggml_tensor * x_conv = ggml_ssm_conv(ctx0, conv_states, x, model.layers[il].ssm_conv1d, state_seq);
+
+ // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache
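+ // (the element offset skips the {d_inner, n_tokens} x part of x_conv, plus one element
+ //  so that, with the d_conv-wide row stride, only the last (d_conv - 1) columns of each row are taken)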
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0,
+ ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)),
+ ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv))));
+
+ // extract x from x_conv
+ x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0);
+
+ // bias
+ x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+ x = ggml_silu(ctx0, x);
+ }
+
+ // ssm
+ {
+ // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
+ struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+ // split
+ struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
+ struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+ struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+ // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
+ dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+ dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+
+ // Custom operator to optimize the parallel associative scan
+ // as described in Annex D of the Mamba paper.
+ // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined,
+ // because only a single tensor can be returned.
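+ // Per token: state = exp(dt*A) * state + dt*B*x ; y = rowwise_dotprod(state, C)
+ // (the D skip connection and the silu(z) gating are applied below, outside the op)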
+ struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq);
+
+ // store last states (the second part of y_ssm_states)
+ ggml_build_forward_expand(gf,
+ ggml_cpy(ctx0,
+ ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)),
+ ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_self.head*d_state*d_inner*ggml_element_size(ssm_states))));
+
+ struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
+
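+ // D skip connection and gating with silu(z), as in the reference Mamba block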
+ // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+ y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+
+ // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
+ cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
+ }
+
+ // residual
+ cur = ggml_add(ctx0, cur, inpL);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ // final rmsnorm
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
return result;
}
+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+ llama_batch dummy;
+ dummy.n_tokens = 0;
+
+ llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+ struct llm_build_context llm(lctx, dummy, cb, false);
+
+ llm.init();
+
+ struct ggml_cgraph * result = llm.build_s_copy();
+
+ llm.free();
+
+ return result;
+}
+
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch,
{
result = llm.build_starcoder2();
} break;
+ case LLM_ARCH_MAMBA:
+ {
+ result = llm.build_mamba();
+ } break;
default:
GGML_ASSERT(false);
}
}
static void llama_set_k_shift(llama_context & lctx) {
- const auto & cparams = lctx.cparams;
-
- const int64_t n_ctx = cparams.n_ctx;
+ const int64_t kv_size = lctx.kv_self.size;
assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
int32_t * data = (int32_t *) lctx.inp_K_shift->data;
- for (int i = 0; i < n_ctx; ++i) {
+ for (int i = 0; i < kv_size; ++i) {
data[i] = lctx.kv_self.cells[i].delta;
}
}
+static void llama_set_s_copy(llama_context & lctx) {
+ const int64_t kv_size = lctx.kv_self.size;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+ int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+ for (int i = 0; i < kv_size; ++i) {
+ data[i] = lctx.kv_self.cells[i].src;
+ }
+}
+
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
//
// set input data
float * data = (float *) lctx.inp_KQ_mask->data;
+ // For causal attention, use only the previous KV cells
+ // of the correct sequence for each token of the batch.
+ // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
}
}
}
+
+ if (kv_self.recurrent) {
+ const int64_t n_kv = kv_self.n;
+
+ {
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
+ float * data = (float *) lctx.inp_s_mask->data;
+
+ // states which are not affected by the current batch are left untouched
+ for (int i = 0; i < n_kv; ++i) {
+ llama_seq_id seq_id = i + lctx.kv_self.head;
+ llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+ bool has_self_seq = kv_cell.has_seq_id(seq_id);
+
+ data[i] = (float) has_self_seq;
+
+ // ensure current sequences will be kept
+ if (!has_self_seq && kv_cell.pos >= 0) {
+ kv_cell.seq_id.insert(seq_id);
+ }
+ }
+ }
+ // For Mamba (and other recurrent architectures),
+ // update the correct state(s)/sequence(s) for each token of the batch.
+ // Like with the KQ_mask, if a token in the batch has multiple sequences,
+ // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv).
+ {
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_seq->buffer));
+ int32_t * data = (int32_t *) lctx.inp_s_seq->data;
+
+ for (int j = 0; j < n_tokens; ++j) {
+ const int32_t n_seq = batch.n_seq_id[j];
+ GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence
+
+ for (int i = 0; i < n_kv; ++i) {
+ if (i < n_seq) {
+ // for this type of model, the head is the minimum seq_id of the batch
+ data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head;
+ } else {
+ data[j*n_kv + i] = -1;
+ }
+ }
+ }
+ }
+ }
}
static void llama_graph_compute(
return 1;
}
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
- //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ if (!kv_self.recurrent) {
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
+ // after enough generations, the benefit from this heuristic disappears
+ // if we start defragmenting the cache, the benefit from this will be more important
+ kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+ //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ }
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
}
}
+ if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
+ llama_set_s_copy(lctx);
+
+ {
+ ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+ }
+
+ {
+ auto & kv_self = lctx.kv_self;
+
+ kv_self.do_copy = false;
+
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
+ kv_self.cells[i].src = i;
+ }
+ }
+ }
+
// defragment the KV cache if needed
if (lctx.kv_self.do_defrag) {
llama_kv_cache_defrag_internal(lctx);
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
+ // do not quantize Mamba's small yet 2D weights
+ // NOTE: can't use LLM_TN here because the layer number is not known
+ quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+ quantize &= name.find("ssm_x.weight") == std::string::npos;
+ quantize &= name.find("ssm_dt.weight") == std::string::npos;
+
enum ggml_type new_type;
void * new_data;
size_t new_size;
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
+ /*.n_parallel =*/ 1,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
+ // TODO: maybe add n_parallel here too
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
- const ggml_type type_k = params.type_k;
- const ggml_type type_v = params.type_v;
+ uint32_t kv_size = cparams.n_ctx;
+ ggml_type type_k = params.type_k;
+ ggml_type type_v = params.type_v;
+
+ // Mamba only needs a constant number of KV cache cells per sequence
+ if (model->arch == LLM_ARCH_MAMBA) {
+ // Mamba needs at least as many KV cells as there are sequences kept at any time
+ kv_size = std::max((uint32_t) 1, params.n_parallel);
+ // it's probably best to keep as much precision as possible for the states
+ type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+ type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+ }
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
}
ctx->backends.push_back(ctx->backend_cpu);
- if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
+ if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
// graph inputs
{
ggml_init_params init_params = {
- /* .mem_size */ ggml_tensor_overhead()*8,
+ /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.recurrent)),
/* .mem_buffer */ nullptr,
/* .no_alloc */ true,
};
ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
- ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
- ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
- ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
+ ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch);
+ ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
+ ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
+ if (ctx->kv_self.recurrent) {
+ ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
+ ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
+ ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
+ }
ggml_set_name(ctx->inp_tokens, "inp_tokens");
ggml_set_name(ctx->inp_embd, "inp_embd");
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
ggml_set_name(ctx->inp_mean, "inp_mean");
ggml_set_name(ctx->inp_cls, "inp_cls");
+ if (ctx->kv_self.recurrent) {
+ ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
+ ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
+ ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
+ }
ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));
LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
return ctx->cparams.n_batch;
}
+uint32_t llama_n_max_seq(const struct llama_context * ctx) {
+ return ctx->kv_self.size;
+}
+
enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.type;
}
case LLM_ARCH_MPT:
case LLM_ARCH_REFACT:
case LLM_ARCH_BLOOM:
+ case LLM_ARCH_MAMBA:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values
llama_kv_cache_clear(ctx->kv_self);
}
-void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
- llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
+bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+ return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
}
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const size_t kv_buf_size = kv_self.total_size();
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ if (kv_self.recurrent) {
+ // v is contiguous for recurrent models
+ // TODO: use tensors other than k and v for state models
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+
+ tmp_buf.resize(v_size);
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
+ data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ continue;
+ }
+
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
size_t kv_buf_size;
uint32_t kv_head;
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
inp += k_size;
+ if (kv_self.recurrent) {
+ // v is contiguous for recurrent models
+ // TODO: use tensors other than k and v for state models
+ const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+
+ ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
+ inp += v_size;
+ continue;
+ }
+
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);