}
}
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
- GGML_UNUSED(ubatch);
-
- const int64_t n_kv = kv_state->get_n_kv();
-
- if (s_mask) {
- GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
- float * data = (float *) s_mask->data;
-
- // clear unused states
- for (int i = 0; i < n_kv; ++i) {
- data[i] = kv_state->s_mask(i);
- }
- }
-}
-
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
return cur;
}
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
- const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
- auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
-
- const auto n_kv = kv_state->get_n_kv();
-
- auto & cur = inp->s_mask;
-
- cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
- ggml_set_input(cur);
-
- res->add_input(std::move(inp));
-
- return cur;
-}
-
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
return cur;
}
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state( // adds graph nodes that zero one state row, gather rows of s through state_copy, and write the extra rows back into s
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
- ggml_tensor * state_mask,
- int32_t n_state,
- int32_t n_seqs) const {
+ int32_t state_size, // row width used to reshape s into {state_size, kv_state->get_size()}
+ int32_t n_seqs, // number of leading state_copy ids gathered into the returned tensor
+ bool avoid_copies) const { // true: skip the gather and return the reshaped view of s itself
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
+ const auto rs_zero = kv_state->get_rs_z(); // row to clear; a negative value means there is none
- ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+ ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
- // copy states
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
- // this shrinks the tensors's ne[1] to n_kv
- states = ggml_get_rows(ctx0, states, state_copy);
+ // Clear a single state which will then be copied to the other cleared states.
+ // Note that this is a no-op when the view is zero-sized.
+ ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); // (rs_zero >= 0) multiplies both the length and the byte offset, so both collapse to 0 when rs_zero is negative
+ ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); // NOTE(review): multiplying by 0 presumably leaves NaN/Inf entries as NaN - confirm stale states cannot contain them
- // clear states of sequences which are starting at the beginning of this batch
- // FIXME: zero-out NANs?
- states = ggml_mul(ctx0, states, state_mask);
+ ggml_tensor * output_states; // the tensor handed back to the caller
+
+ if (!avoid_copies) {
+ // copy states
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+ // {state_size, kv_size} -> {state_size, n_seqs}
+ output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); // rows are selected by the first n_seqs ids of state_copy
+ ggml_build_forward_expand(gf, output_states);
+ } else {
+ // FIXME: make the gathering operation happen before the copy below
+ // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+ output_states = states; // ungathered: the caller receives every row of the reshaped s
+ }
- // copy states which won't be changed further (between n_seqs and n_kv)
+ // copy extra states which won't be changed further (between n_seqs and n_kv)
+ ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); // ids at positions n_seqs..n_kv-1 of state_copy pick the rows to carry over
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
- ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
- ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+ states_extra,
+ ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); // destination starts at row kv_head + n_seqs of s and spans n_kv - n_seqs rows
- // the part of the states that will be used and modified
- return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+ return output_states;
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
- ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
- ggml_tensor * token_shift = build_copy_mask_state(
- gf, token_shift_all, state_copy, state_mask,
+ ggml_tensor * token_shift = build_recurrent_state(
+ gf, token_shift_all, state_copy,
hparams.n_embd_k_s(), n_seqs);
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
const llama_kv_cache_recurrent_state * kv_state;
};
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
- llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
- virtual ~llm_graph_input_s_mask() = default;
-
- void set_input(const llama_ubatch * ubatch) override;
-
- ggml_tensor * s_mask; // F32 [1, n_kv]
-
- const llama_kv_cache_recurrent_state * kv_state;
-};
-
class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
- ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
// recurrent
//
- ggml_tensor * build_copy_mask_state(
+ ggml_tensor * build_recurrent_state( // takes over from build_copy_mask_state; the state_mask parameter is gone
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
- ggml_tensor * state_mask,
- int32_t n_state,
- int32_t n_seqs) const;
+ int32_t state_size, // width of one state row (renamed from n_state)
+ int32_t n_seqs, // how many leading state_copy ids are gathered for the result
+ bool avoid_copies = false) const; // defaults to false so callers passing five arguments get the gathered result
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
- ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;
bool success = true;
- // TODO: here we have to verify that all ubatches can fit in the cells
- // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
- // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
- //
- // $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
- //
- // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
- //
- GGML_UNUSED(ubatches);
- //for (const auto & ubatch : ubatches) {
- // if (!find_slot(ubatch)) {
- // success = false;
- // break;
- // }
- //}
+ for (const auto & ubatch : ubatches) {
+ if (!find_slot(ubatch)) {
+ success = false;
+ break;
+ }
+ }
// restore the original state
cells = std::move(org_cells);
}
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
- const uint32_t n_tokens = ubatch.n_tokens;
- const uint32_t n_seqs = ubatch.n_seqs;
+ const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
- if (head > used + 2*n_tokens) {
+ if (head > used + 2*n_seqs) {
head = 0;
}
empty_cell.src = orig_cell.src;
orig_cell.seq_id.erase(seq_id);
empty_cell.seq_id.insert(seq_id); // will be overwritten
+ GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
}
seq_meta.tail = next_empty_cell;
// find next empty cell
if (s + 1 < n_seqs) {
- next_empty_cell += 1;
for (uint32_t i = 0; i < size; ++i) {
+ next_empty_cell += 1;
if (next_empty_cell >= size) { next_empty_cell -= size; }
kv_cell & cell = cells[next_empty_cell];
if (cell.is_empty()) { break; }
- next_empty_cell += 1;
}
}
}
// gather and re-order
for (uint32_t s = 0; s < n_seqs; ++s) {
- int32_t dst_id = s + min;
- int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+ const int32_t dst_id = s + min; // target position for the s-th sequence of the ubatch
+ const int32_t src_id = cells[ubatch.seq_id[s][0]].tail; // cell currently recorded as the tail for that sequence's first seq_id
if (dst_id != src_id) {
kv_cell & dst_cell = cells[dst_id];
kv_cell & src_cell = cells[src_id];
std::swap(dst_cell.src, src_cell.src);
std::swap(dst_cell.seq_id, src_cell.seq_id);
- // swap tails (assuming they NEVER overlap)
- for (const llama_seq_id seq_id : src_cell.seq_id) {
- cells[seq_id].tail = src_id;
- }
- for (const llama_seq_id seq_id : dst_cell.seq_id) {
- cells[seq_id].tail = dst_id;
+ // swap tails: scan every tail entry instead of only the seq_ids held by the two swapped cells
+ for (uint32_t i = 0; i < size; ++i) {
+ int32_t & tail = cells[i].tail; // reference, so the assignments below update the cell in place
+ if (tail == src_id) {
+ tail = dst_id;
+ } else if (tail == dst_id) { // else-if keeps an entry that was just redirected from being flipped back
+ tail = src_id;
+ }
}
}
}
// update the pos of the used seqs
for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
- int32_t cell_id = s + min;
+ const int32_t cell_id = s + min;
kv_cell & cell = cells[cell_id];
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
}
}
+ // Find first cell without src refs, to use as the zero-ed state
+ {
+ // TODO: bake-in src refcounts in the cell metadata
+ std::vector<int32_t> refcounts(size, 0); // refcounts[c] = number of cells whose src equals c
+ for (size_t i = 0; i < size; ++i) {
+ const int32_t src = cells[i].src;
+ if (src >= 0) {
+ refcounts[src] += 1; // NOTE(review): no upper-bound check on src is visible here - confirm src < size always holds
+ }
+ }
+
+ rs_z = -1; // stays -1 when every cell in [min, max] is referenced as a source
+ for (int i = min; i <= max; ++i) {
+ if (refcounts[i] == 0) {
+ rs_z = i;
+ break;
+ }
+ }
+
+ for (int i = min; i <= max; ++i) {
+ if (cells[i].src < 0) {
+ GGML_ASSERT(rs_z >= 0); // a cell with no source needs an unreferenced cell to read from
+ cells[i].src0 = rs_z; // cells without a source read from the zero-ed state
+ } else {
+ // Stage the source ids for all used cells to allow correct seq_* behavior
+ // and still make these values available when setting the inputs
+ cells[i].src0 = cells[i].src;
+ }
+ cells[i].src = i; // avoid moving or clearing twice
+ }
+ }
+
// allow getting the range of used cells, from head to head + n
head = min;
n = max - min + 1;
}
bool llama_kv_cache_recurrent::get_can_shift() const {
- return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
- const uint32_t cell_id = i + head;
-
- //////////////////////////////////////////////
- // TODO: this should not mutate the KV cache !
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
- // prevent out-of-bound sources
- if (cell.src < 0 || (uint32_t) cell.src >= size) {
- cell.src = cell_id;
- }
-
- int32_t res = cell.src;
-
- // TODO: do not mutate the KV cache
- // ensure copy only happens once
- if (cell.src != (int32_t) cell_id) {
- cell.src = cell_id;
- }
-
- return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
- const uint32_t cell_id = i + head;
-
- //////////////////////////////////////////////
- // TODO: this should not mutate the KV cache !
- kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
- float res = (float) (cell.src >= 0);
-
- // only clear once
- if (cell.src < 0) {
- cell.src = cell_id;
- }
-
- return res;
+ // shifting the pos is trivial for recurrent models
+ return true; // the body being replaced returned false
}
size_t llama_kv_cache_recurrent::total_size() const {
return is_full ? 0 : kv->head;
}
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const { // index of the cell used as the zero-ed state
+ return is_full ? 0 : kv->rs_z; // is_full short-circuits to 0; otherwise forwards the cache's rs_z, which is -1 when no such cell was picked
+}
+
uint32_t llama_kv_cache_recurrent_state::get_size() const {
return kv->size;
}
}
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
- return kv->s_copy(i);
-}
-
-float llama_kv_cache_recurrent_state::s_mask(int i) const {
- return kv->s_mask(i);
+ return kv->cells[i + kv->head].src0; // i is an offset from kv->head; reads the staged source id (src0) directly from the cell
}
bool get_can_shift() const override;
- // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
- int32_t s_copy(int i) const;
- float s_mask(int i) const;
-
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
// computed before each graph build
uint32_t n = 0;
+ // first zero-ed state
+ int32_t rs_z = -1;
+
// TODO: optimize for recurrent state needs
struct kv_cell {
llama_pos pos = -1;
- int32_t src = -1; // used to copy states
+ int32_t src = -1; // used to know where states should be copied from
+ int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
uint32_t get_n_kv() const;
uint32_t get_head() const;
+ int32_t get_rs_z() const;
uint32_t get_size() const;
ggml_tensor * get_k_l(int32_t il) const;
ggml_tensor * get_v_l(int32_t il) const;
int32_t s_copy(int i) const;
- float s_mask(int i) const;
private:
const llama_memory_status status;
head_cur = 0;
}
- // otherwise, one cell per token.
-
if (n_tokens > cells.size()) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return -1;
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
- ggml_tensor * state_mask = build_inp_s_mask();
for (int il = 0; il < n_layer; ++il) {
// norm
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
- //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
- cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
+ cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
- ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
// (ab)using the KV cache to store the states
- ggml_tensor * conv = build_copy_mask_state(
- gf, conv_states_all, state_copy, state_mask,
+ ggml_tensor * conv = build_recurrent_state(
+ gf, conv_states_all, state_copy,
hparams.n_embd_k_s(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
- ggml_tensor * ssm = build_copy_mask_state(
- gf, ssm_states_all, state_copy, state_mask,
+ ggml_tensor * ssm = build_recurrent_state(
+ gf, ssm_states_all, state_copy,
hparams.n_embd_v_s(), n_seqs);
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
- ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
}
- ggml_tensor * wkv_state = build_copy_mask_state(
- gf, kv_state->get_v_l(il), state_copy, state_mask,
+ ggml_tensor * wkv_state = build_recurrent_state(
+ gf, kv_state->get_v_l(il), state_copy,
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output;
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
ggml_tensor * state_copy = build_inp_s_copy();
- ggml_tensor * state_mask = build_inp_s_mask();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
- gf, state_copy, state_mask, ubatch, il
+ gf, state_copy, ubatch, il
);
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
1
);
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+ cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
- ggml_tensor * state_mask = build_inp_s_mask();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
- gf, state_copy, state_mask, ubatch, il
+ gf, state_copy, ubatch, il
);
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
1
);
- cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+ cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
- ggml_tensor * state_mask,
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
- ggml_tensor * wkv_state = build_copy_mask_state(
- gf, kv_state->get_v_l(il), state_copy, state_mask,
+ ggml_tensor * wkv_state = build_recurrent_state(
+ gf, kv_state->get_v_l(il), state_copy,
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
ggml_tensor * state_copy = build_inp_s_copy();
- ggml_tensor * state_mask = build_inp_s_mask();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
- gf, state_copy, state_mask, ubatch, il
+ gf, state_copy, ubatch, il
);
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
1
);
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+ cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
- ggml_tensor * state_mask = build_inp_s_mask();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
- gf, state_copy, state_mask, ubatch, il
+ gf, state_copy, ubatch, il
);
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
1
);
- cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+ cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));