};
head = 0;
- size = kv_size;
- used = 0;
cells.resize(kv_size);
}
void llama_kv_cache_unified::clear() {
- for (uint32_t i = 0; i < size; ++i) {
- cells[i].pos = -1;
- cells[i].seq_id.clear();
- }
+ cells.reset();
head = 0;
- used = 0;
for (auto & buf : bufs) {
ggml_backend_buffer_clear(buf.get(), 0);
}
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
- uint32_t new_head = size;
+ uint32_t new_head = cells.size();
if (p0 < 0) {
p0 = 0;
}
if (p1 < 0) {
p1 = std::numeric_limits<llama_pos>::max();
}
- for (uint32_t i = 0; i < size; ++i) {
- if (cells[i].pos >= p0 && cells[i].pos < p1) {
- if (seq_id < 0) {
- cells[i].seq_id.clear();
- } else if (cells[i].has_seq_id(seq_id)) {
- cells[i].seq_id.erase(seq_id);
- } else {
- continue;
- }
-
- if (cells[i].is_empty()) {
- // keep count of the number of used cells
- if (cells[i].pos >= 0) {
- used--;
- }
-
- cells[i].pos = -1;
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (!cells.pos_in(i, p0, p1)) {
+ continue;
+ }
- if (new_head == size) {
- new_head = i;
- }
+ if ((seq_id < 0 || cells.seq_has(i, seq_id)) && cells.seq_rm(i, seq_id)) {
+ if (new_head == cells.size()) {
+ new_head = i;
}
}
}
// If we freed up a slot, set head to it so searching can start there.
- if (new_head != size && new_head < head) {
+ if (new_head != cells.size() && new_head < head) {
head = new_head;
}
if (p1 < 0) {
p1 = std::numeric_limits<llama_pos>::max();
}
- // otherwise, this is the KV of a Transformer-like model
- head = 0;
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (!cells.pos_in(i, p0, p1)) {
+ continue;
+ }
- for (uint32_t i = 0; i < size; ++i) {
- if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
- cells[i].seq_id.insert(seq_id_dst);
+ if (cells.seq_has(i, seq_id_src)) {
+ cells.seq_add(i, seq_id_dst);
}
}
}
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
- uint32_t new_head = size;
+ uint32_t new_head = cells.size();
- for (uint32_t i = 0; i < size; ++i) {
- if (!cells[i].has_seq_id(seq_id)) {
- if (cells[i].pos >= 0) {
- used--;
- }
-
- cells[i].pos = -1;
- cells[i].seq_id.clear();
-
- if (new_head == size){
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (cells.seq_keep(i, seq_id)) {
+ if (new_head == cells.size()) {
new_head = i;
}
- } else {
- cells[i].seq_id.clear();
- cells[i].seq_id.insert(seq_id);
}
}
// If we freed up a slot, set head to it so searching can start there.
- if (new_head != size && new_head < head) {
+ if (new_head != cells.size() && new_head < head) {
head = new_head;
}
}
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
- if (delta == 0) {
+void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+ if (shift == 0) {
return;
}
- uint32_t new_head = size;
+ uint32_t new_head = cells.size();
if (p0 < 0) {
p0 = 0;
}
if (p1 < 0) {
p1 = std::numeric_limits<llama_pos>::max();
}
- // If there is no range then return early to avoid looping over the
+ // If there is no range then return early to avoid looping over all cells.
if (p0 == p1) {
return;
}
- for (uint32_t i = 0; i < size; ++i) {
- if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
- has_shift = true;
-
- cells[i].pos += delta;
- cells[i].delta += delta;
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (!cells.pos_in(i, p0, p1)) {
+ continue;
+ }
- if (cells[i].pos < 0) {
- if (!cells[i].is_empty()) {
- used--;
- }
- cells[i].pos = -1;
- cells[i].seq_id.clear();
- if (new_head == size) {
+ if (cells.seq_has(i, seq_id)) {
+ if (cells.pos_add(i, shift)) {
+ if (new_head == cells.size()) {
new_head = i;
}
}
// If we freed up a slot, set head to it so searching can start there.
// Otherwise we just start the next search from the beginning.
- head = new_head != size ? new_head : 0;
+ head = new_head != cells.size() ? new_head : 0;
}
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (d == 1) {
return;
}
- for (uint32_t i = 0; i < size; ++i) {
- if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
- has_shift = true;
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (!cells.pos_in(i, p0, p1)) {
+ continue;
+ }
- {
- llama_pos p_old = cells[i].pos;
- cells[i].pos /= d;
- cells[i].delta += cells[i].pos - p_old;
- }
+ if (cells.seq_has(i, seq_id)) {
+ cells.pos_div(i, d);
}
}
}
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
llama_pos result = std::numeric_limits<llama_pos>::max();
- for (uint32_t i = 0; i < size; ++i) {
- if (cells[i].has_seq_id(seq_id)) {
- result = std::min(result, cells[i].pos);
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (cells.seq_has(i, seq_id)) {
+ result = std::min(result, cells.pos_get(i));
}
}
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_pos result = -1;
- for (uint32_t i = 0; i < size; ++i) {
- if (cells[i].has_seq_id(seq_id)) {
- result = std::max(result, cells[i].pos);
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (cells.seq_has(i, seq_id)) {
+ result = std::max(result, cells.pos_get(i));
}
}
return result;
}
void llama_kv_cache_unified::restore() {
- for (const auto & [id, cell] : recovery.cells) {
- // TODO: move to new `struct kv_cells`
- const bool is_empty0 = cells[id].is_empty();
- const bool is_empty1 = cell.is_empty();
-
- if (!is_empty0 && is_empty1) {
- used--;
- } else if (is_empty0 && !is_empty1) {
- used++;
- }
-
- cells[id] = cell;
+ for (auto & state : recovery.states) {
+ cells.set(state.i, state.cells);
}
recovery.clear();
}
void llama_kv_cache_unified::commit() {
- if (recovery.cells.empty()) {
+ if (recovery.states.empty()) {
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
return;
auto * sched = lctx.get_sched();
- if (has_shift) {
+ if (cells.get_has_shift()) {
if (!get_can_shift()) {
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
}
need_reserve = true;
}
- {
- has_shift = false;
-
- for (uint32_t i = 0; i < size; ++i) {
- cells[i].delta = 0;
- }
- }
+ cells.reset_shift();
}
if (do_defrag) {
void llama_kv_cache_unified::defrag_sched(float thold) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
- const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
+ const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > thold) {
do_defrag = true;
}
void llama_kv_cache_unified::set_full() {
- n = size;
+ n = cells.size();
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
- if (head > used + 2*ubatch.n_tokens) {
+ if (head > cells.get_used() + 2*ubatch.n_tokens) {
head = 0;
}
// otherwise, one cell per token.
- if (n_tokens > size) {
- LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
+ if (n_tokens > cells.size()) {
+ LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
return false;
}
std::string ss;
if (n_swa > 0) {
- for (uint32_t i = 0; i < size; ++i) {
+ for (uint32_t i = 0; i < cells.size(); ++i) {
- if (cells[i].pos == -1) {
+ if (cells.is_empty(i)) {
ss += '.';
} else {
- ss += std::to_string(*cells[i].seq_id.begin());
+ ss += 'x';
}
if (i%256 == 255) {
ss += '\n';
uint32_t n_tested = 0;
while (true) {
- if (head + n_tokens > size) {
- n_tested += size - head;
+ if (head + n_tokens > cells.size()) {
+ n_tested += cells.size() - head;
head = 0;
continue;
}
bool found = true;
for (uint32_t i = 0; i < n_tokens; i++) {
- if (cells[head + i].pos >= 0) {
+ // TODO: improve to accept cells that are masked by the SWA
+ if (!cells.is_empty(head + i)) {
found = false;
head += i + 1;
n_tested += i + 1;
break;
}
- if (n_tested >= size) {
+ if (n_tested >= cells.size()) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return false;
}
}
- for (uint32_t i = 0; i < n_tokens; ++i) {
- // remember the original state
- if (recovery.cells.find(head + i) == recovery.cells.end()) {
- recovery.cells[head + i] = cells[head + i];
- }
+ // store the old state of the cells in the recovery stack
+ recovery.states.push_back({head, cells.cp(head, n_tokens)});
- cells[head + i].pos = ubatch.pos[i];
+ for (uint32_t i = 0; i < n_tokens; ++i) {
+ cells.pos_set(head + i, ubatch.pos[i]);
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
- cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
+ cells.seq_add(head + i, ubatch.seq_id[i][j]);
}
}
- used += n_tokens;
-
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
- n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
+ n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
#ifdef FIND_SLOT_DEBUG
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
}
uint32_t llama_kv_cache_unified::get_size() const {
- return size;
+ return cells.size();
}
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
int n_attended = 0;
- for (uint32_t i = 0; i < size; ++i) {
- const llama_pos p0 = cells[i].pos;
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (!cells.seq_has(i, seq_id)) {
+ continue;
+ }
+
+ const llama_pos p0 = cells.pos_get(i);
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
n_attended++;
}
if (is_masked_swa(p0, pmax)) {
- if (seq_id < 0) {
- cells[i].seq_id.clear();
- } else if (cells[i].has_seq_id(seq_id)) {
- cells[i].seq_id.erase(seq_id);
- } else {
- continue;
- }
-
- if (cells[i].is_empty()) {
- // keep count of the number of used cells
- if (cells[i].pos >= 0) {
- used--;
- }
-
- cells[i].pos = -1;
- }
+ cells.seq_rm(i, seq_id);
}
}
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
for (int i = 0; i < n_kv; ++i) {
- const llama_pos p0 = cells[i].pos;
+ float f = 0.0f;
bool masked = false;
- // mask the token if not the same sequence
- masked = masked || (!cells[i].has_seq_id(seq_id));
+ if (cells.is_empty(i)) {
+ masked = true;
+ } else {
+ const llama_pos p0 = cells.pos_get(i);
- // mask future tokens
- masked = masked || (causal_attn && p0 > p1);
+ // mask the token if not the same sequence
+ masked = masked || (!cells.seq_has(i, seq_id));
- // apply SWA if any
- masked = masked || (is_masked_swa(p0, p1));
+ // mask future tokens
+ masked = masked || (causal_attn && p0 > p1);
- float f = 0.0f;
+ // apply SWA if any
+ masked = masked || (is_masked_swa(p0, p1));
+
+ if (!masked && hparams.use_alibi) {
+ f = -std::abs(p0 - p1);
+ }
+ }
if (masked) {
f = -INFINITY;
- } else if (hparams.use_alibi) {
- f = -std::abs(p0 - p1);
}
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
int32_t * data = (int32_t *) dst->data;
- for (uint32_t i = 0; i < size; ++i) {
- data[i] = cells[i].delta;
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
}
}
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_kv; ++i) {
- data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+ // the position when the cell is empty is irrelevant - it will be masked out later in the attention
+ const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
}
}
}
ggml_tensor * k =
ggml_view_3d(ctx, layer.k,
- n_embd_head_k, n_head_kv, size,
+ n_embd_head_k, n_head_kv, cells.size(),
ggml_row_size(layer.k->type, n_embd_head_k),
ggml_row_size(layer.k->type, n_embd_k_gqa),
0);
} else {
view_v_src = ggml_view_2d(ctx, layer.v,
nm, n_embd_v_gqa,
- ggml_row_size(layer.v->type, size),
+ ggml_row_size(layer.v->type, cells.size()),
ggml_row_size(layer.v->type, i));
view_v_dst = ggml_view_2d(ctx, layer.v,
nm, n_embd_v_gqa,
- ggml_row_size(layer.v->type, size),
+ ggml_row_size(layer.v->type, cells.size()),
ggml_row_size(layer.v->type, id));
}
const uint32_t n_layer = layers.size();
const uint32_t n_kv = cell_max();
- const uint32_t n_used = used;
+ const uint32_t n_used = cells.get_used();
assert(n_used <= n_kv);
ids.resize(n_kv, n_kv);
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
- const auto & cell0 = cells[i0];
-
- if (!cell0.is_empty()) {
+ if (!cells.is_empty(i0)) {
ids[i0] = i0;
continue;
uint32_t nh = 1;
// determine the size of the hole
- while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
+ while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
nh++;
}
// starting from the end, find nh non-empty cells
for (; is > i0; --is) {
- const auto & cell1 = cells[is];
-
- if (cell1.is_empty() || ids[is] != n_kv) {
+ if (cells.is_empty(is) || ids[is] != n_kv) {
continue;
}
// go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) {
- auto & cell1 = cells[i1];
-
- if (cell1.is_empty() || ids[i1] != n_kv) {
+ if (cells.is_empty(i1) || ids[i1] != n_kv) {
if (n_moves == max_moves) {
stop = true;
break;
ids[i1] = i0 + nf;
// move the cell meta data
- cells[i0 + nf] = cell1;
+ cells.mv(i1, i0 + nf);
- // clear the old cell and move the head there
- cell1 = kv_cell();
head = n_used;
if (!cont) {
}
uint32_t llama_kv_cache_unified::cell_max() const {
- for (uint32_t i = size; i > 0; --i) {
- const kv_cell & cell = cells[i - 1];
-
- if (cell.pos >= 0 && !cell.is_empty()) {
+ for (uint32_t i = cells.size(); i > 0; --i) {
+ if (!cells.is_empty(i - 1)) {
return i;
}
}
return 0;
}
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
- if (p0 < 0) {
- return true;
- }
+ assert(p0 >= 0 && p1 >= 0);
switch (swa_type) {
case LLAMA_SWA_TYPE_NONE:
// Count the number of cells with the specified seq_id
// Find all the ranges of cells with this seq id (or all, when -1)
- uint32_t cell_range_begin = size;
- for (uint32_t i = 0; i < size; ++i) {
- const auto & cell = cells[i];
- if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+ uint32_t cell_range_begin = cells.size();
+
+ for (uint32_t i = 0; i < cells.size(); ++i) {
+ if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
++cell_count;
- if (cell_range_begin == size) {
+ if (cell_range_begin == cells.size()) {
cell_range_begin = i;
}
} else {
- if (cell_range_begin != size) {
+ if (cell_range_begin != cells.size()) {
cell_ranges.emplace_back(cell_range_begin, i);
- cell_range_begin = size;
+ cell_range_begin = cells.size();
}
}
}
- if (cell_range_begin != size) {
- cell_ranges.emplace_back(cell_range_begin, size);
+
+ if (cell_range_begin != cells.size()) {
+ cell_ranges.emplace_back(cell_range_begin, cells.size());
}
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
- const auto & cell = cells[i];
- const llama_pos pos = cell.pos;
- const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+ std::vector<llama_seq_id> seq_ids;
+
+ for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
+ if (cur == seq_id || seq_id == -1) {
+ if (cells.seq_has(i, cur)) {
+ seq_ids.push_back(cur);
+ }
+ }
+ }
+
+ const llama_pos pos = cells.pos_get(i);
+ const uint32_t n_seq_id = seq_ids.size();
io.write(&pos, sizeof(pos));
io.write(&n_seq_id, sizeof(n_seq_id));
- if (n_seq_id) {
- for (auto seq_id : cell.seq_id) {
- io.write(&seq_id, sizeof(seq_id));
- }
+ for (const auto & seq_id : seq_ids) {
+ io.write(&seq_id, sizeof(seq_id));
}
}
}
}
} else {
// When v is transposed, we also need the element size and get the element ranges from each row
- const uint32_t kv_size = size;
+ const uint32_t kv_size = cells.size();
for (const auto & layer : layers) {
const uint32_t il = layer.il;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
- if (n_seq_id != 0) {
+ if (n_seq_id != 1) {
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
return false;
}
- batch.pos[i] = pos;
- batch.n_seq_id[i] = 1;
- batch.seq_id[i] = &dest_seq_id;
+ // read the sequence id, but directly discard it - we will use dest_seq_id instead
+ {
+ llama_seq_id seq_id;
+ io.read_to(&seq_id, sizeof(seq_id));
+ }
+
+ batch.pos[i] = pos;
+ batch.n_seq_id[i] = n_seq_id;
+ batch.seq_id[i] = &dest_seq_id;
}
if (!find_slot(batch)) {
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
- GGML_ASSERT(head + cell_count <= size);
- GGML_ASSERT(cells[head].pos == batch.pos[0]);
- GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
- GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
- GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+ GGML_ASSERT(head + cell_count <= cells.size());
+ GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
+ GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
+ GGML_ASSERT(cells.seq_has(head, dest_seq_id));
+ GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
} else {
// whole KV cache restore
- if (cell_count > size) {
+ if (cell_count > cells.size()) {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
return false;
}
clear();
for (uint32_t i = 0; i < cell_count; ++i) {
- kv_cell & cell = cells[i];
-
llama_pos pos;
uint32_t n_seq_id;
io.read_to(&pos, sizeof(pos));
io.read_to(&n_seq_id, sizeof(n_seq_id));
- cell.pos = pos;
+ cells.pos_set(i, pos);
for (uint32_t j = 0; j < n_seq_id; ++j) {
llama_seq_id seq_id;
return false;
}
- cell.seq_id.insert(seq_id);
+ cells.seq_add(i, seq_id);
}
}
head = 0;
- used = cell_count;
}
return true;
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
return false;
}
- if (cell_count > size) {
- LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+ if (cell_count > cells.size()) {
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
return false;
}
if (this->v_trans != (bool) v_trans) {
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- const size_t dst_offset = (head + j * size) * v_size_el;
+ const size_t dst_offset = (head + j * cells.size()) * v_size_el;
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
}
}
kv_swa ->seq_keep(seq_id);
}
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
- kv_base->seq_add(seq_id, p0, p1, delta);
- kv_swa ->seq_add(seq_id, p0, p1, delta);
+void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+ kv_base->seq_add(seq_id, p0, p1, shift);
+ kv_swa ->seq_add(seq_id, p0, p1, shift);
}
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
}
}
-void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
- if (delta == 0) {
+void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+ if (shift == 0) {
return;
}
if (tail_id >= 0) {
kv_cell & cell = cells[tail_id];
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
- cell.pos += delta;
+ cell.pos += shift;
}
}
}
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"
+#include "llama-kv-cells.h"
#include "ggml-cpp.h"
virtual void defrag_sched(float thold) = 0;
// simulate full cache, used for allocating worst-case compute buffers
+ // TODO: remove
virtual void set_full() = 0;
//
//
// =============================================================================================================
- // TODO: refactor and simplify this
+ // TODO: refactor and simplify this [TAG: KV_API]
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
// llama_kv_cache_unified specific API
//
uint32_t get_n() const;
uint32_t get_size() const;
// get views of the current state of the cache
const llama_model & model;
const llama_hparams & hparams;
- struct kv_cell {
- llama_pos pos = -1;
- llama_pos delta = 0;
-
- // TODO: replace with bitset uint64_t
- std::set<llama_seq_id> seq_id;
-
- bool has_seq_id(const llama_seq_id & id) const {
- return seq_id.find(id) != seq_id.end();
- }
-
- bool is_empty() const {
- return seq_id.empty();
- }
-
- bool is_same_seq(const kv_cell & other) const {
- return seq_id == other.seq_id;
- }
- };
-
struct kv_layer {
// layer index in the model
// note: can be different from the layer index in the KV cache
ggml_tensor * v;
};
- bool has_shift = false;
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
- uint32_t size = 0; // total number of cells, shared across all sequences
- uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
// computed before each graph build
+ // TODO: cells should start to maintain this value dynamically based on the edits
uint32_t n = 0;
const uint32_t n_seq_max = 1;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
- std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
+ llama_kv_cells_unified cells;
+
std::vector<kv_layer> layers;
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
// recovery information used to restore the KV cells to their original state in case of a failure
+ // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
+ // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
struct {
void clear() {
- cells.clear();
+ states.clear();
}
- std::unordered_map<uint32_t, kv_cell> cells;
+ struct state {
+ uint32_t i;
+
+ llama_kv_cells_unified cells;
+ };
+
+ // stack with the partial states before each ubatch
+ std::vector<state> states;
} recovery;
// defrag
bool defrag_prepare(int32_t n_max_nodes);
// find how many cells are currently in use
+ // TODO: optimize
uint32_t cell_max() const;
size_t total_size() const;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
--- /dev/null
+++ b/src/llama-kv-cells.h
+#pragma once
+
+#include "llama.h"
+#include "llama-cparams.h"
+
+#include <bitset>
+#include <cassert>
+#include <vector>
+
+// meta information about KV cells that can be part of multiple sequences at the same time
+// TODO: add unit tests
+class llama_kv_cells_unified {
+public:
+ void reset() {
+ for (uint32_t i = 0; i < pos.size(); ++i) {
+ pos[i] = -1;
+ shift[i] = 0;
+ seq[i].reset();
+ }
+
+ used = 0;
+ has_shift = false;
+ }
+
+ void reset_shift() {
+ has_shift = false;
+
+ for (uint32_t i = 0; i < shift.size(); ++i) {
+ shift[i] = 0;
+ }
+ }
+
+ uint32_t size() const {
+ return pos.size();
+ }
+
+ void resize(uint32_t n) {
+ pos.resize(n);
+ shift.resize(n);
+ seq.resize(n);
+
+ reset();
+ }
+
+ bool is_empty(uint32_t i) const {
+ assert(i < pos.size());
+ assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
+
+ return pos[i] == -1;
+ }
+
+ uint32_t get_used() const {
+ return used;
+ }
+
+ bool get_has_shift() const {
+ return has_shift;
+ }
+
+ // move cell isrc to idst (used during defrag)
+ void mv(uint32_t isrc, uint32_t idst) {
+ assert(isrc < pos.size());
+ assert(idst < pos.size());
+
+ pos [idst] = pos [isrc];
+ shift[idst] = shift[isrc];
+ seq [idst] = seq [isrc];
+
+ pos [isrc] = -1;
+ shift[isrc] = 0;
+ seq [isrc].reset();
+ }
+
+ // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
+ llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
+ assert(i + n <= pos.size());
+
+ llama_kv_cells_unified res;
+
+ res.resize(n);
+
+ for (uint32_t j = 0; j < n; ++j) {
+ res.pos[j] = pos[i + j];
+ res.seq[j] = seq[i + j];
+
+ assert(shift[i + j] == 0);
+ }
+
+ return res;
+ }
+
+ // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
+ void set(uint32_t i, const llama_kv_cells_unified & other) {
+ assert(i + other.pos.size() <= pos.size());
+
+ for (uint32_t j = 0; j < other.pos.size(); ++j) {
+ if (pos[i + j] == -1 && other.pos[j] != -1) {
+ used++;
+ }
+
+ if (pos[i + j] != -1 && other.pos[j] == -1) {
+ used--;
+ }
+
+ pos[i + j] = other.pos[j];
+ seq[i + j] = other.seq[j];
+
+ assert(shift[i + j] == 0);
+ }
+ }
+
+ // note: call only if the cell is not empty; for seq_id >= 0 the cell must contain seq_id
+ // if seq_id < 0, the cell is removed from all sequences (the caller passes -1 to match any sequence)
+ // return true if the cell becomes empty
+ bool seq_rm(uint32_t i, llama_seq_id seq_id) {
+ assert(i < pos.size());
+ assert(pos[i] != -1);
+
+ if (seq_id < 0) {
+ seq[i].reset();
+ } else {
+ assert(seq[i].test(seq_id));
+
+ seq[i].reset(seq_id);
+ }
+
+ if (seq[i].none()) {
+ pos[i] = -1;
+
+ used--;
+
+ return true;
+ }
+
+ return false;
+ }
+
+ // keep only seq_id in the cell, removing all other sequences
+ // return true if the cell becomes empty (i.e. it was not empty and did not contain seq_id)
+ bool seq_keep(uint32_t i, llama_seq_id seq_id) {
+ assert(i < pos.size());
+
+ if (seq[i].test(seq_id)) {
+ seq[i].reset();
+ seq[i].set(seq_id);
+
+ return false;
+ }
+
+ if (seq[i].any()) {
+ seq[i].reset();
+ pos[i] = -1;
+
+ used--;
+
+ return true;
+ }
+
+ assert(pos[i] == -1);
+
+ return false;
+ }
+
+ bool seq_has(uint32_t i, llama_seq_id seq_id) const {
+ assert(i < pos.size());
+ assert(seq_id >= 0);
+
+ return seq[i].test(seq_id);
+ }
+
+ // note: call only if the cell is not empty and the seq_id is not in the cell
+ void seq_add(uint32_t i, llama_seq_id seq_id) {
+ assert(i < pos.size());
+ assert(pos[i] != -1);
+ assert(!seq[i].test(seq_id));
+
+ seq[i].set(seq_id);
+ }
+
+ // note: call only if the cell is not empty
+ llama_pos pos_get(uint32_t i) const {
+ assert(i < pos.size());
+ assert(pos[i] != -1);
+
+ return pos[i];
+ }
+
+ // note: call only if the cell is not empty
+ llama_pos get_shift(uint32_t i) const {
+ assert(i < pos.size());
+ assert(pos[i] != -1);
+
+ return shift[i];
+ }
+
+ // check if a cell is not empty and its position is within [p0, p1)
+ bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
+ assert(i < pos.size());
+
+ return pos[i] >= p0 && pos[i] < p1;
+ }
+
+ // set the position of an empty cell
+ // does not modify "has_shift"
+ // note: call only if the cell is empty
+ void pos_set(uint32_t i, llama_pos p) {
+ assert(i < pos.size());
+ assert(pos[i] == -1);
+
+ pos[i] = p;
+ used++;
+ }
+
+ // pos[i] = pos[i] + d
+ // sets "has_shift" to true
+ // note: call only if the cell is not empty
+ bool pos_add(uint32_t i, llama_pos d) {
+ assert(i < pos.size());
+ assert(pos[i] != -1);
+
+ pos[i] += d;
+ shift[i] += d;
+
+ has_shift = true;
+
+ if (pos[i] < 0) {
+ pos[i] = -1;
+ seq[i].reset();
+
+ used--;
+
+ return true;
+ }
+
+ return false;
+ }
+
+ // pos[i] = pos[i] / d
+ // sets "has_shift" to true
+ // note: call only if the cell is not empty
+ void pos_div(uint32_t i, int d) {
+ assert(i < pos.size());
+ assert(pos[i] != -1);
+
+ const llama_pos p_old = pos[i];
+
+ pos[i] /= d;
+ shift[i] += pos[i] - p_old;
+
+ has_shift = true;
+ }
+
+private:
+ uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
+
+ bool has_shift = false;
+
+ std::vector<llama_pos> pos;
+
+ // this array accumulates any applied shifts to the pos array since the last reset_shift() call
+ // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
+ //
+ // cells.pos_add(x, shift_x);
+ //   cells.pos_div(y, d_y);
+ // ...
+ //
+ //   if (cells.get_has_shift()) {
+ // for (int i = 0; i < n; ++i) {
+ // auto shift_i = cells.get_shift(i);
+ // ...
+ // }
+ // cells.reset_shift();
+ // }
+ //
+ std::vector<llama_pos> shift;
+
+ std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
+};
+