common_params params;
+ params.n_predict = 128;
+
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
return 1;
}
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
- // TODO: remove this stuff
- class batch_guard {
- public:
- batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
- }
-
- ~batch_guard() {
- if (!is_done) {
- kv_slot_restorer.restore();
- }
- }
-
- void done() {
- is_done = true;
- }
-
- void save(const llama_kv_cache_slot_info & slot_info) {
- kv_slot_restorer.save(slot_info);
- }
-
- private:
- bool is_done = false;
-
- llama_kv_slot_restorer kv_slot_restorer;
- };
-
- batch_guard bg(*kv_self);
+ llama_kv_cache_guard kv_guard(kv_self.get());
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
return -2;
};
+ // handle any pending defrags/shifts
+ kv_self_update();
+
int64_t n_outputs_prev = 0;
while (sbatch.n_tokens > 0) {
// find KV slot
{
- kv_self_update();
- // if we have enough unused cells before the current head ->
- // better to start searching from the beginning of the cache, hoping to fill it
- if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
- kv_self->head = 0;
+ if (!kv_self->find_slot(ubatch)) {
+ LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+ return 1;
}
- const auto slot_info = kv_self->find_slot(ubatch);
- if (!slot_info) {
- LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
- return -3;
- }
-
- bg.save(slot_info);
-
if (!kv_self->recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
}
}
- // update the kv ring buffer
- {
- kv_self->head += ubatch.n_tokens;
-
- // Ensure kv cache head points to a valid index.
- if (kv_self->head >= kv_self->size) {
- kv_self->head = 0;
- }
- }
-
// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
}
// finalize the batch processing
- bg.done();
+ kv_guard.commit();
// set output mappings
{
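
Taken together, the decode-side changes reduce to a single RAII pattern. Below is a condensed sketch of that flow (not the real llama_decode, which does much more); `next_ubatch()` is a hypothetical stand-in for the sbatch splitting logic:

```cpp
// Condensed sketch of the new error-handling flow; kv is the
// llama_kv_cache_unified instance, next_ubatch() is hypothetical.
int32_t decode_sketch(llama_kv_cache_unified & kv, llama_sbatch & sbatch) {
    // the guard's destructor calls kv.restore() on every exit path
    llama_kv_cache_guard kv_guard(&kv);

    while (sbatch.n_tokens > 0) {
        const llama_ubatch ubatch = next_ubatch(sbatch);

        // each successful call appends one slot_range to kv.pending.ranges
        if (!kv.find_slot(ubatch)) {
            return 1; // the guard rolls back all ranges claimed so far
        }

        // ... build and compute the graph for this ubatch ...
    }

    // success: clear the pending ranges, so the destructor's restore()
    // sees an empty list and returns early
    kv_guard.commit();
    return 0;
}
```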
#include <map>
#include <stdexcept>
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
}
return false;
}
}
+
+ return true;
}
for (uint32_t i = 0; i < size; ++i) {
}
}
+void llama_kv_cache_unified::restore() {
+ if (pending.ranges.empty()) {
+ return;
+ }
+
+ // TODO: tmp - move to llama_kv_cache_recurrent
+ if (recurrent) {
+ seq_rm(-1, -1, -1);
+ return;
+ }
+
+ uint32_t new_head = size;
+
+ for (auto & range : pending.ranges) {
+ for (uint32_t i = range.c0; i < range.c1; ++i) {
+ cells[i].seq_id.clear();
+
+ // keep count of the number of used cells
+ if (cells[i].pos >= 0) {
+ used--;
+ }
+
+ cells[i].pos = -1;
+ cells[i].src = -1;
+ }
+
+ new_head = std::min(new_head, range.c0);
+ }
+
+ if (new_head != size && new_head < head) {
+ head = new_head;
+ }
+}
+
+void llama_kv_cache_unified::commit() {
+ if (pending.ranges.empty()) {
+ LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+ __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+ return;
+ }
+
+ pending.ranges.clear();
+}
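
A hypothetical walkthrough of the bookkeeping these two functions manage, assuming an empty cache and illustrative cell indices (the actual positions depend on where find_slot lands):

```cpp
// Fragment; kv is a llama_kv_cache_unified, ubatch_a/ubatch_b are 32-token ubatches.
kv.find_slot(ubatch_a); // claims e.g. cells [0, 32)  -> pending.ranges = { {0, 32} }
kv.find_slot(ubatch_b); // claims e.g. cells [32, 64) -> pending.ranges = { {0, 32}, {32, 64} }

// failure path: empty every cell in every pending range, fix up `used`,
// and rewind `head` toward the lowest claimed cell
kv.restore();

// success path (instead of restore): keep the cells and drop the
// bookkeeping, so a later restore() has nothing to undo
kv.commit(); // pending.ranges.clear()
```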
+
bool llama_kv_cache_unified::get_can_shift() const {
return can_shift;
}
-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
const llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+ // if we have enough unused cells before the current head ->
+ // better to start searching from the beginning of the cache, hoping to fill it
+ if (head > used + 2*ubatch.n_tokens) {
+ head = 0;
+ }
+
if (recurrent) {
// For recurrent state architectures (like Mamba or RWKV),
// each cache cell can store the state for a whole sequence.
// too big seq_id
// TODO: would it be possible to resize the cache instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
- return llama_kv_cache_slot_info_failed;
+ return false;
}
if (j > 0) {
llama_kv_cell & seq = cells[seq_id];
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
// sanity check
- return llama_kv_cache_slot_info(n >= n_seqs);
+ return n >= n_seqs;
}
// otherwise, one cell per token.
if (n_tokens > size) {
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
- return llama_kv_cache_slot_info_failed;
+ return false;
}
uint32_t n_tested = 0;
if (n_tested >= size) {
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
- return llama_kv_cache_slot_info_failed;
+ return false;
}
}
used += n_tokens;
- return llama_kv_cache_slot_info(head, head + n_tokens);
+ pending.ranges.push_back({head, head + n_tokens});
+
+ return true;
}
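
The return-type change also simplifies the caller contract: instead of unpacking a slot_info structure, callers get a plain success flag, and the claimed range is tracked internally. A minimal, hypothetical caller:

```cpp
// Minimal caller sketch; kv and ubatch are assumed to be set up already.
if (!kv.find_slot(ubatch)) {
    // no cells were claimed and nothing is pending
    return false;
}

// on success, kv.head points at the first cell of the slot, and the range
// {head, head + ubatch.n_tokens} stays in kv.pending.ranges until the next
// commit() or restore()
```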
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
+ commit();
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i;
+ virtual void restore() = 0; // call if batch processing fails - restores the cache state
+ virtual void commit() = 0; // call after successful batch processing - clears any pending state
+
virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
bool get_can_edit() const override { return get_can_shift(); }
};
+struct llama_kv_cache_guard {
+ llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
+
+ ~llama_kv_cache_guard() {
+ kv->restore();
+ }
+
+ void commit() {
+ kv->commit();
+ }
+
+private:
+ llama_kv_cache * kv;
+};
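
Because restore() runs in the destructor, the guard is exception-safe as well as early-return-safe. A sketch, assuming a compute step `run_graph()` that can fail or throw (hypothetical name):

```cpp
{
    llama_kv_cache_guard g(&kv);

    if (!kv.find_slot(ubatch)) {
        return; // destructor restores (a no-op: nothing was claimed)
    }

    run_graph(); // on a throw, stack unwinding still reaches kv->restore()

    g.commit();  // reached only on success; makes the destructor a no-op
}
```

Note that commit() must be called explicitly: a forgotten commit fails safe, rolling the work back, rather than silently persisting partial state.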
+
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
int32_t src = -1; // used by recurrent state models to copy states
int32_t tail = -1;
}
};
-// a structure holds information about the slot found in llama_kv_cache_find_slot
-struct llama_kv_cache_slot_info {
- std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
- bool found = false; // the slot was found
-
- explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
- llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
-
- operator bool() const { return found; }
-};
-
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
void clear() override;
void defrag() override;
+ void restore() override;
+ void commit() override;
+
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
- // returns a structure holding information about the slot found
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
- llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
+ bool find_slot(const llama_ubatch & batch);
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
- // state save/load
+ // commit/restore cache
+
+ struct slot_range {
+ uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+ uint32_t c1 = 0;
+ };
+
+ // pending cell updates that are not yet committed
+ struct {
+ std::vector<slot_range> ranges;
+ } pending;
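
The ranges are half-open intervals of cell indices, matching how find_slot records them as {head, head + n_tokens}. For example (values hypothetical):

```cpp
// a 32-token slot found at head == 64 would be recorded as:
slot_range r = {64, 96}; // cells 64..95; c1 is one past the last cell
```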
+
+ // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
-//
-// kv cache restore
-//
-
-// saves the kv_cache state for future recovery.
-// used to rollback llama_kv_cache_find_slot changes.
-struct llama_kv_slot_restorer {
- struct llama_kv_cache_state {
- uint32_t head = 0;
- uint32_t n = 0;
- } old_state;
-
- // for non-recurrent models only
- // list of slots to restore
- std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
-
- bool do_restore = false;
-
- llama_kv_cache_unified & cache;
-
- explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
- old_state.head = cache.head;
- old_state.n = cache.n;
- }
-
- // saves a slot information for future restoration
- void save(const llama_kv_cache_slot_info & slot) {
- if (slot) {
- do_restore = true;
- if (slot.boundaries.first != slot.boundaries.second) {
- slot_boundaries.push_back(slot.boundaries);
- }
- }
- }
-
- // must be explicitly called to restore the kv_cache state
- // and rollback changes from all llama_kv_cache_find_slot calls
- void restore() {
- if (do_restore) {
- cache.head = old_state.head;
- cache.n = old_state.n;
-
- if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
- cache.seq_rm(-1, -1, -1);
- } else {
- for (auto & slot : slot_boundaries) {
- cache.seq_rm(-1, slot.first, slot.second);
- }
- }
- }
- }
-};
-
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);