return kv_self;
}
-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+ if (!memory) {
+ return;
+ }
+
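+ // force a defrag on the next call to kv_self_update()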
+ memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
if (!memory) {
return false;
}
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- if (!kv_self->update(*this)) {
- // no updates have been performed
- return false;
+ {
+ // TODO: remove in the future
+ optimize |= memory_force_optimize;
+ memory_force_optimize = false;
+
+ const auto kv_state = kv_self->init_update(this, optimize);
+ switch (kv_state->get_status()) {
+ case LLAMA_MEMORY_STATUS_SUCCESS:
+ {
+ // noop
+ } break;
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
+ {
+ // no updates need to be performed
+ return false;
+ }
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+ {
+ LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+ return false;
+ }
+ }
+
+ if (!kv_state->apply()) {
+ LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+ }
}
// if the KV cache did any computation, we have to reserve a new worst-case graph
const auto kv_state = kv_self->init_full();
if (!kv_state) {
- throw std::runtime_error("failed to initialize KV cache");
+ throw std::runtime_error("failed to initialize memory state");
}
const uint32_t n_seqs = cparams.n_seq_max;
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
if (!gf) {
- LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
}
return true;
n_outputs_all = 1;
}
+ bool did_optimize = false;
+
// handle any pending defrags/shifts
- kv_self_update();
+ kv_self_update(false);
llama_memory_state_ptr kv_state;
- bool did_defrag = false;
-
while (true) {
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
if (!kv_state) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
} break;
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
+ {
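+ // NO_UPDATE is only expected from init_update(), never from init_batch()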
+ LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+
+ return -2;
+ }
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
{
- if (!did_defrag) {
- did_defrag = true;
+ if (!did_optimize) {
+ did_optimize = true;
- kv_self->defrag_sched(-1.0f);
- if (kv_self_update()) {
- LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
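+ // force a shift/defrag pass and retry the batch once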
+ if (kv_self_update(true)) {
+ LLAMA_LOG_DEBUG("%s: retrying batch of size %d after cache optimization\n", __func__, batch.n_tokens);
continue;
}
}
- LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+ LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
return 1;
}
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
+ LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
return -2;
}
}
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();
- // decide if we need to defrag the kv cache
- if (cparams.defrag_thold > 0.0f) {
- kv_self->defrag_sched(cparams.defrag_thold);
- }
-
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
// deprecated
void llama_kv_self_update(llama_context * ctx) {
- ctx->kv_self_update();
+ ctx->kv_self_update(false);
}
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
// deprecated
void llama_kv_self_defrag(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
- if (!kv) {
- return;
- }
-
// force defrag
- kv->defrag_sched(-1.0f);
+ ctx->kv_self_defrag_sched();
}
bool llama_kv_self_can_shift(const llama_context * ctx) {
// return true if the KV cache was updated
// TODO: remove
- bool kv_self_update();
+ bool kv_self_update(bool optimize);
+ void kv_self_defrag_sched();
enum llama_pooling_type pooling_type() const;
std::unique_ptr<llama_memory_i> memory;
+ // TODO: temporary, until the llama_kv_self_defrag() API is removed
+ bool memory_force_optimize = false;
+
// decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr;
#include "llama-kv-cache-recurrent.h"
#include "llama-impl.h"
+#include "llama-io.h"
#include "llama-batch.h"
#include "llama-model.h"
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
}
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+ GGML_UNUSED(lctx);
+ GGML_UNUSED(optimize);
+
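+ // recurrent caches have no pending updates to apply (no shifts or defrag)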
+ return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
// simply remember the full state because it is very small for this type of cache
// TODO: optimize
return success;
}
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
- GGML_UNUSED(lctx);
- // noop
- return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
- GGML_UNUSED(thold);
- // noop
-}
-
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
llama_memory_state_ptr init_full() override;
- bool update(llama_context & lctx) override;
-
- void defrag_sched(float thold) override;
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool prepare(const std::vector<llama_ubatch> & ubatches);
assert(heads_base.size() == heads_swa.size());
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
}
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
- bool res = false;
-
- res = res | kv_base->update(lctx);
- res = res | kv_swa ->update(lctx);
-
- return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
- kv_base->defrag_sched(thold);
- kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
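+ // the state forwards the update to both the base and SWA caches and combines their statuses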
+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
}
bool llama_kv_cache_unified_iswa::get_can_shift() const {
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
- llama_memory_status status,
- llama_kv_cache_unified_iswa * kv) : status(status) {
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+ llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+ state_base = kv->get_base()->init_full();
+ state_swa = kv->get_swa ()->init_full();
+
+ status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa * kv,
+ llama_context * lctx,
+ bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+ state_base = kv->get_base()->init_update(lctx, optimize);
+ state_swa = kv->get_swa ()->init_update(lctx, optimize);
+
+ status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
- llama_memory_status status,
llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches)
- : status(status),
- sbatch(std::move(sbatch)),
- ubatches(std::move(ubatches)) {
- // note: here we copy the ubatches. not sure if this is ideal
- state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
- state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
- }
+ : status(LLAMA_MEMORY_STATUS_SUCCESS),
+ sbatch(std::move(sbatch)),
+ ubatches(std::move(ubatches)) {
+ // note: here we copy the ubatches. not sure if this is ideal
+ state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+ state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+ status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
return ubatches[i_next];
}
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
- return state_base.get();
+ return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
}
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
- return state_swa.get();
+ return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
}
llama_memory_state_ptr init_full() override;
- bool update(llama_context & lctx) override;
-
- void defrag_sched(float thold) override;
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
// used to create a full-cache state
llama_kv_cache_unified_iswa_state(
- llama_memory_status status,
llama_kv_cache_unified_iswa * kv);
+ // used to create an update state
+ llama_kv_cache_unified_iswa_state(
+ llama_kv_cache_unified_iswa * kv,
+ llama_context * lctx,
+ bool optimize);
+
// used to create a state from a batch
llama_kv_cache_unified_iswa_state(
- llama_memory_status status,
llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_base,
const llama_kv_cache_unified_state * get_swa() const;
private:
- const llama_memory_status status;
+ llama_memory_status status;
//llama_kv_cache_unified_iswa * kv;
std::vector<llama_ubatch> ubatches;
- std::unique_ptr<llama_kv_cache_unified_state> state_base;
- std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+ llama_memory_state_ptr state_base;
+ llama_memory_state_ptr state_swa;
};
#include "llama-kv-cache-unified.h"
#include "llama-impl.h"
+#include "llama-io.h"
#include "llama-model.h"
#include "llama-context.h"
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+ return std::make_unique<llama_kv_cache_unified_state>(
this, std::move(sbatch), std::move(heads), std::move(ubatches));
}
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
- return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+ return std::make_unique<llama_kv_cache_unified_state>(this);
}
-std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
- std::vector<uint32_t> res;
+llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+ bool do_shift = get_has_shift();
+
+ defrag_info dinfo;
+
+ // see if we need to defrag
+ {
+ bool do_defrag = optimize;
+
+ const auto thold = lctx->get_cparams().defrag_thold;
+
+ if (!do_defrag && thold > 0.0f) {
+ const auto n_kv = cells.used_max_p1();
+
+ // - do not defrag small contexts (i.e. < 2048 tokens)
+ // - count the padding towards the number of used tokens
+ const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+ if (fragmentation > thold) {
+ LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+ do_defrag = true;
+ }
+ }
+
+ if (do_defrag) {
+ dinfo = defrag_prepare(lctx->graph_max_nodes());
+ }
+ }
+
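+ // the shift/defrag (if any) is applied later via llama_kv_cache_unified_state::apply()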
+ return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+}
+
+llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+ llama_kv_cache_unified::ubatch_heads res;
struct state {
uint32_t head_old; // old position of the head, before placing the ubatch
return res;
}
-bool llama_kv_cache_unified::update(llama_context & lctx) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
bool updated = false;
- auto * sched = lctx.get_sched();
+ auto * sched = lctx->get_sched();
- if (cells.get_has_shift()) {
+ if (do_shift) {
if (!get_can_shift()) {
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
}
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(sched);
- auto * gf = lctx.graph_init();
+ auto * gf = lctx->graph_init();
- auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+ auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
return updated;
res->set_inputs(nullptr);
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
return updated;
}
cells.reset_shift();
}
- if (do_defrag) {
+ if (!dinfo.empty()) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
- if (defrag_prepare(lctx.graph_max_nodes())) {
- ggml_backend_sched_reset(sched);
-
- auto * gf = lctx.graph_init();
-
- auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
- if (!res) {
- LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
- return updated;
- }
+ // apply moves:
+ {
+ const auto n_kv = dinfo.ids.size();
- if (!ggml_backend_sched_alloc_graph(sched, gf)) {
- LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
- return updated;
- }
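+ // move the cell metadata according to the plan: cell i goes to dinfo.ids[i];
+ // an id equal to n_kv marks a cell that does not move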
+ for (uint32_t i = 0; i < n_kv; ++i) {
+ assert(dinfo.ids[i] <= n_kv);
- res->set_inputs(nullptr);
+ if (dinfo.ids[i] == n_kv) {
+ continue;
+ }
- if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
- LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
- return updated;
+ cells.mv(i, dinfo.ids[i]);
}
- updated = true;
+ // reset the head so we can find the first free slot during the next ubatch
+ head = 0;
}
- do_defrag = false;
- }
+ ggml_backend_sched_reset(sched);
- return updated;
-}
+ auto * gf = lctx->graph_init();
-void llama_kv_cache_unified::defrag_sched(float thold) {
- const auto n_kv = cells.used_max_p1();
+ auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
+ if (!res) {
+ LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+ return updated;
+ }
+
+ if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+ return updated;
+ }
- // - do not defrag small contexts (i.e. < 2048 tokens)
- // - count the padding towards the number of used tokens
- const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+ res->set_inputs(nullptr);
- // queue defragmentation for next llama_kv_cache_update
- if (fragmentation > thold) {
- LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+ if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+ LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+ return updated;
+ }
- do_defrag = true;
+ updated = true;
}
+
+ return updated;
}
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
return cells.size();
}
+bool llama_kv_cache_unified::get_has_shift() const {
+ return cells.get_has_shift();
+}
+
uint32_t llama_kv_cache_unified::get_n_kv() const {
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
}
}
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
- const llama_cparams & cparams,
- ggml_context * ctx,
- ggml_cgraph * gf) const {
+ const llama_cparams & cparams,
+ ggml_context * ctx,
+ ggml_cgraph * gf,
+ const defrag_info & dinfo) const {
auto res = std::make_unique<llm_graph_result>();
- const auto & ids = defrag_info.ids;
+ const auto & ids = dinfo.ids;
#if 0
// CPU defrag
return res;
}
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
const uint32_t n_layer = layers.size();
const uint32_t n_kv = cells.used_max_p1();
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
// determine which KV cells to move where
- //
- // cell i moves to ids[i]
- //
- // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
- //
- auto & ids = defrag_info.ids;
+ defrag_info res;
+ auto & ids = res.ids;
- ids.clear();
ids.resize(n_kv, n_kv);
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
// this cell goes to (i0 + nf)
ids[i1] = i0 + nf;
- // move the cell meta data
- cells.mv(i1, i0 + nf);
-
- head = n_used;
-
if (!cont) {
n_moves++;
cont = true;
}
if (n_moves == 0) {
- return false;
+ return {};
}
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
- return true;
+ return res;
}
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
- llama_memory_status status,
- llama_kv_cache_unified * kv) : status(status), kv(kv) {
- n_kv = kv->get_size();
- head = 0;
- }
+ llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+ n_kv = kv->get_size();
+ head = 0;
+}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
- llama_memory_status status,
- llama_kv_cache_unified * kv,
- llama_sbatch sbatch,
- std::vector<uint32_t> heads,
- std::vector<llama_ubatch> ubatches)
- : status(status),
- kv(kv),
- sbatch(std::move(sbatch)),
- heads(std::move(heads)),
- ubatches(std::move(ubatches)) {
+ llama_kv_cache_unified * kv,
+ llama_context * lctx,
+ bool do_shift,
+ defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
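+ // nothing to shift and no cells to move -> report that there is no update to apply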
+ if (!do_shift && this->dinfo.empty()) {
+ status = LLAMA_MEMORY_STATUS_NO_UPDATE;
}
+}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+ llama_kv_cache_unified * kv,
+ llama_sbatch sbatch,
+ llama_kv_cache_unified::ubatch_heads heads,
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+}
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
bool llama_kv_cache_unified_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+ // no ubatches -> this is a KV cache update
+ if (ubatches.empty()) {
+ kv->update(lctx, do_shift, dinfo);
+
+ return true;
+ }
+
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
n_kv = kv->get_n_kv();
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
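+ // head positions (one per ubatch), as returned by prepare() and consumed by apply_ubatch()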
+ using ubatch_heads = std::vector<uint32_t>;
+
+ struct defrag_info {
+ bool empty() const {
+ return ids.empty();
+ }
+
+ // contains information about which cell moves where:
+ // - cell i moves to ids[i]
+ // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
+ std::vector<uint32_t> ids;
+ };
+
llama_kv_cache_unified(
const llama_model & model,
layer_filter_cb && filter,
llama_memory_state_ptr init_full() override;
- bool update(llama_context & lctx) override;
-
- void defrag_sched(float thold) override;
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
uint32_t get_size() const;
+ bool get_has_shift() const;
+
//
// graph_build API
//
// find places for the provided ubatches in the cache, returns the head locations
// return empty vector on failure
- std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
+ ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+
+ bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
// return the cell position where we can insert the ubatch
// return -1 on failure to find a contiguous slot of kv cells
ggml_tensor * v;
};
- bool do_defrag = false;
- bool v_trans = true; // the value tensor is transposed
+ bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
- // defrag
- struct {
- std::vector<uint32_t> ids;
- } defrag_info;
-
- // return true if cells have been moved
- bool defrag_prepare(int32_t n_max_nodes);
+ // compute the defrag plan; returns an empty defrag_info if no cells need to move
+ defrag_info defrag_prepare(int32_t n_max_nodes) const;
size_t total_size() const;
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
- ggml_cgraph * gf) const;
+ ggml_cgraph * gf,
+ const defrag_info & dinfo) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
class llama_kv_cache_unified_state : public llama_memory_state_i {
public:
+ // some shorthands
+ using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+ using defrag_info = llama_kv_cache_unified::defrag_info;
+
// used for errors
llama_kv_cache_unified_state(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_state(
- llama_memory_status status,
llama_kv_cache_unified * kv);
- // used to create a state from a batch
+ // used to create an update state
+ llama_kv_cache_unified_state(
+ llama_kv_cache_unified * kv,
+ llama_context * lctx,
+ bool do_shift,
+ defrag_info dinfo);
+
+ // used to create a decode state from a batch
llama_kv_cache_unified_state(
- llama_memory_status status,
llama_kv_cache_unified * kv,
llama_sbatch sbatch,
- std::vector<uint32_t> heads,
+ ubatch_heads heads,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state();
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
- const llama_memory_status status;
+ llama_memory_status status;
llama_kv_cache_unified * kv;
+ llama_context * lctx;
+
+ //
+ // update state
+ //
+
+ bool do_shift = false;
+
+ defrag_info dinfo;
+
+ //
+ // batch processing state
+ //
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
- std::vector<uint32_t> heads;
+ ubatch_heads heads;
+
std::vector<llama_ubatch> ubatches;
//
#pragma once
#include "llama.h"
-#include "llama-io.h"
#include "llama-memory.h"
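+// only forward declarations are needed here; the full definitions live in llama-io.h, which is included by the .cpp files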
+class llama_io_write_i;
+class llama_io_read_i;
+
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
+ // TODO: move the init_ interfaces to llama_memory_i
+
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
- // process any pending defrag/shift/etc. operations
- // optionally call once before processing a new batch
- // return true if any operations were performed
- virtual bool update(llama_context & lctx) = 0;
-
- // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
- // TODO: change to
- // llama_memory_state_ptr init_defrag(float thold) = 0;
- //
- virtual void defrag_sched(float thold) = 0;
+ // prepare for any pending memory updates, such as shifts, defrags, etc.
+ // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+ virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;
#include "llama-memory.h"
+
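+// failures take precedence; otherwise the combined status is SUCCESS if either state has an update, NO_UPDATE if neither does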
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
+ bool has_update = false;
+
+ switch (s0) {
+ case LLAMA_MEMORY_STATUS_SUCCESS:
+ {
+ has_update = true;
+ break;
+ }
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
+ {
+ break;
+ }
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+ {
+ return s0;
+ }
+ }
+
+ switch (s1) {
+ case LLAMA_MEMORY_STATUS_SUCCESS:
+ {
+ has_update = true;
+ break;
+ }
+ case LLAMA_MEMORY_STATUS_NO_UPDATE:
+ {
+ break;
+ }
+ case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+ case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+ {
+ return s1;
+ }
+ }
+
+ // if either status has an update, then the combined status has an update
+ return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
+}
virtual bool get_can_edit() const = 0;
};
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
enum llama_memory_status {
LLAMA_MEMORY_STATUS_SUCCESS = 0,
+ LLAMA_MEMORY_STATUS_NO_UPDATE,
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
};
+// helper function for combining the status of two memory states
+// useful for implementing hybrid memory types (e.g. iSWA)
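+// e.g. combine(SUCCESS, NO_UPDATE) -> SUCCESS, combine(NO_UPDATE, NO_UPDATE) -> NO_UPDATE,
+//      and any FAILED_* status is propagated as-is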
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
+
// the interface for managing the memory state during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_unified_state
// get the current ubatch
virtual const llama_ubatch & get_ubatch() const = 0;
- // get the status of the memory state
+ // get the status of the memory state - used for error handling and checking if any updates would be applied
virtual llama_memory_status get_status() const = 0;
};