struct llama_model;
struct llama_context;
struct llama_sampler;
- struct llama_kv_cache;
+
+ typedef struct llama_memory_i * llama_memory_t;
+
+ struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
typedef int32_t llama_pos;
typedef int32_t llama_token;
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
- LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
+ LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
+ DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
int32_t il_end);
//
- // KV cache
+ // Memory
+ //
+
+ // Clear the memory contents
+ LLAMA_API void llama_memory_clear(llama_memory_t mem);
+
+ // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+ // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+ // seq_id < 0 : match any sequence
+ // p0 < 0 : [0, p1]
+ // p1 < 0 : [p0, inf)
+ LLAMA_API bool llama_memory_seq_rm(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1);
+
+ // Copy all tokens that belong to the specified sequence to another sequence
+ // p0 < 0 : [0, p1]
+ // p1 < 0 : [p0, inf)
+ LLAMA_API void llama_memory_seq_cp(
+ llama_memory_t mem,
+ llama_seq_id seq_id_src,
+ llama_seq_id seq_id_dst,
+ llama_pos p0,
+ llama_pos p1);
+
+ // Removes all tokens that do not belong to the specified sequence
+ LLAMA_API void llama_memory_seq_keep(
+ llama_memory_t mem,
+ llama_seq_id seq_id);
+
+ // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+ // p0 < 0 : [0, p1]
+ // p1 < 0 : [p0, inf)
+ LLAMA_API void llama_memory_seq_add(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ llama_pos delta);
+
+ // Integer division of the positions by factor of `d > 1`
+ // p0 < 0 : [0, p1]
+ // p1 < 0 : [p0, inf)
+ LLAMA_API void llama_memory_seq_div(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ int d);
+
+ // Returns the smallest position present in the memory for the specified sequence
+ // This is typically non-zero only for SWA caches
+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+ // Return -1 if the sequence is empty
+ LLAMA_API llama_pos llama_memory_seq_pos_min(
+ llama_memory_t mem,
+ llama_seq_id seq_id);
+
+ // Returns the largest position present in the memory for the specified sequence
+ // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+ // Return -1 if the sequence is empty
+ LLAMA_API llama_pos llama_memory_seq_pos_max(
+ llama_memory_t mem,
+ llama_seq_id seq_id);
+
+ // Check if the memory supports shifting
+ LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
+
+ //
+ // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
//
// Returns the number of tokens in the KV cache (slow, use only for debug)
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear(
- struct llama_context * ctx);
+ struct llama_context * ctx);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
- LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+ DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
- LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+ DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
"simply remove this call, updates are applied lazily on the next llama_decode()");
//
//
// Returns the *actual* size in bytes of the state
- // (logits, embedding and kv_cache)
+ // (logits, embedding and memory)
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
size_t n_token_count),
"use llama_state_save_file instead");
- // Get the exact size needed to copy the KV cache of a single sequence
+ // Get the exact size needed to copy the state of a single sequence
LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx,
llama_seq_id seq_id);
- // Copy the KV cache of a single sequence into the specified buffer
+ // Copy the state of a single sequence into the specified buffer
LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx,
uint8_t * dst,
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success
- // < 0 - error. the KV cache state is restored to the state before this call
+ // < 0 - error. the memory state is restored to the state before this call
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
// Process a batch of tokens.
- // Requires KV cache.
+ // Requires the context to have a memory.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning.
- // Upon non-zero return values, the KV cache state is restored to the state before this call
+ // Upon non-zero return values, the memory state is restored to the state before this call
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// 2 - aborted
llama-hparams.cpp
llama-impl.cpp
llama-io.cpp
- llama-kv-cache.cpp
llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp
llama-kv-cache-recurrent.cpp
#include "llama-impl.h"
#include "llama-io.h"
+#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"
-#include "llama-kv-cache.h"
#include <cinttypes>
#include <cstring>
int n_nodes_tg = -1;
// simulate full KV cache
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- const auto kv_state = kv_self->init_full();
- if (!kv_state) {
+ const auto mstate = memory->init_full();
+ if (!mstate) {
throw std::runtime_error("failed to initialize KV cache");
}
// reserve pp graph first so that buffers are only allocated once
{
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
// reserve with tg graph to get the number of splits and nodes
{
- auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+ auto * gf = graph_reserve(1, 1, 1, mstate.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
return cparams.n_threads_batch;
}
-llama_kv_cache * llama_context::get_kv_self() {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- return kv_self;
-}
-
-const llama_kv_cache * llama_context::get_kv_self() const {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
- return kv_self;
+llama_memory_t llama_context::get_memory() const {
+ return memory.get();
}
void llama_context::kv_self_defrag_sched() {
return false;
}
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
{
// TODO: remove in the future
optimize |= memory_force_optimize;
memory_force_optimize = false;
- const auto kv_state = kv_self->init_update(this, optimize);
- switch (kv_state->get_status()) {
+ const auto mstate = memory->init_update(this, optimize);
+ switch (mstate->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
// noop
}
}
- if (!kv_state->apply()) {
+ if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
}
}
- // if the KV cache did any computation, we have to reserve a new worst-case graph
- const auto kv_state = kv_self->init_full();
- if (!kv_state) {
- throw std::runtime_error("failed to initialize memory state");
- }
+ // if the memory module did any computation, we have to reserve a new worst-case graph
+ {
+ const auto mstate = memory->init_full();
+ if (!mstate) {
+ throw std::runtime_error("failed to initialize memory state");
+ }
- const uint32_t n_seqs = cparams.n_seq_max;
- const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+ const uint32_t n_seqs = cparams.n_seq_max;
+ const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
- auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
- if (!gf) {
- LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+ auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+ if (!gf) {
+ LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+ }
}
return true;
}
}
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
// temporary allocate memory for the input batch if needed
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
const llama_batch & batch = batch_allocr.batch;
// handle any pending defrags/shifts
kv_self_update(false);
- llama_memory_state_ptr kv_state;
+ llama_memory_state_ptr mstate;
while (true) {
- kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
- if (!kv_state) {
+ mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+ if (!mstate) {
return -2;
}
- switch (kv_state->get_status()) {
+ switch (mstate->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
} break;
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
- LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+ LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
return -2;
}
int64_t n_outputs_prev = 0;
do {
- const auto & ubatch = kv_state->get_ubatch();
+ const auto & ubatch = mstate->get_ubatch();
// count the outputs in this u_batch
{
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
ggml_status status;
- const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
+ const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
- llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+ memory->seq_rm(s, pos_min[s], -1);
}
switch (status) {
}
n_outputs_prev += n_outputs;
- } while (kv_state->next());
+ } while (mstate->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
{
bool sorted_output = true;
- auto & out_ids = kv_state->out_ids();
+ auto & out_ids = mstate->out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
}
}
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- if (kv_self != nullptr) {
+ if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
- kv_self->state_write(io);
+ memory->state_write(io);
}
return io.n_bytes();
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->state_read(io);
+ memory->state_read(io);
}
return io.n_bytes();
GGML_UNUSED(seq_id);
if (memory) {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->state_write(io, seq_id);
+ memory->state_write(io, seq_id);
}
return io.n_bytes();
GGML_UNUSED(seq_id);
if (memory) {
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->state_read(io, seq_id);
+ memory->state_read(io, seq_id);
}
return io.n_bytes();
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
- llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
- kv_self->clear();
+ memory->clear();
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
batch.n_tokens = n_batch;
int64_t n_outputs_all = n_tokens_all;
- auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
- if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+ auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+ if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
}
uint32_t pos_batch = 0;
do {
- const auto & ubatch = kv_state->get_ubatch();
+ const auto & ubatch = mstate->get_ubatch();
n_outputs = ubatch.n_tokens;
- if (!kv_state->apply()) {
+ if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
break;
}
auto * gf = graph_init();
- auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
+ auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
struct ggml_context * ctx_compute_opt;
{
ggml_free(ctx_compute_opt);
pos_batch += ubatch.n_tokens;
- } while (kv_state->next());
+ } while (mstate->next());
}
}
return &ctx->get_model();
}
+// deprecated
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
- return ctx->get_kv_self();
+ return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
}
// deprecated
return res ? 0 : -1;
}
+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+ return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem) {
+ mem->clear();
+}
+
+bool llama_memory_seq_rm(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1) {
+ return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+ llama_memory_t mem,
+ llama_seq_id seq_id_src,
+ llama_seq_id seq_id_dst,
+ llama_pos p0,
+ llama_pos p1) {
+ mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ llama_pos delta) {
+ mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+ llama_memory_t mem,
+ llama_seq_id seq_id,
+ llama_pos p0,
+ llama_pos p1,
+ int d) {
+ mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+ llama_memory_t mem,
+ llama_seq_id seq_id) {
+ return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+ return mem->get_can_shift();
+}
+
//
// kv cache
//
// deprecated
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ const auto * kv = llama_get_memory(ctx);
if (!kv) {
return 0;
}
// deprecated
// note: this is the same as above - will be removed anyway, so it's ok
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ const auto * kv = llama_get_memory(ctx);
if (!kv) {
return 0;
}
}
void llama_kv_self_clear(llama_context * ctx) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
- kv->clear();
+ llama_memory_clear(kv);
}
bool llama_kv_self_seq_rm(
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return true;
}
- return kv->seq_rm(seq_id, p0, p1);
+ return llama_memory_seq_rm(kv, seq_id, p0, p1);
}
void llama_kv_self_seq_cp(
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
- kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+ llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
- kv->seq_keep(seq_id);
+ llama_memory_seq_keep(kv, seq_id);
}
void llama_kv_self_seq_add(
llama_pos p0,
llama_pos p1,
llama_pos delta) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
- kv->seq_add(seq_id, p0, p1, delta);
+ llama_memory_seq_add(kv, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_div(
llama_pos p0,
llama_pos p1,
int d) {
- auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return;
}
- kv->seq_div(seq_id, p0, p1, d);
+ llama_memory_seq_div(kv, seq_id, p0, p1, d);
}
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return -1;
}
- return kv->seq_pos_min(seq_id);
+ return llama_memory_seq_pos_min(kv, seq_id);
}
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return -1;
}
- return kv->seq_pos_max(seq_id);
+ return llama_memory_seq_pos_max(kv, seq_id);
}
// deprecated
}
bool llama_kv_self_can_shift(const llama_context * ctx) {
- const auto * kv = ctx->get_kv_self();
+ auto * kv = llama_get_memory(ctx);
if (!kv) {
return false;
}
- return kv->get_can_shift();
+ return llama_memory_can_shift(kv);
}
// llama state API
#include <vector>
struct llama_model;
-struct llama_kv_cache;
class llama_io_read_i;
class llama_io_write_i;
-class llama_memory_i;
-class llama_memory_state_i;
+struct llama_memory_i;
+struct llama_memory_state_i;
struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
uint32_t n_threads() const;
uint32_t n_threads_batch() const;
- llama_kv_cache * get_kv_self();
- const llama_kv_cache * get_kv_self() const;
+ llama_memory_t get_memory() const;
// return true of the KV cache was updated
// TODO: remove
struct llama_ubatch;
struct llama_cparams;
-class llama_memory_state_i;
+struct llama_memory_state_i;
class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state;
#include "llama-batch.h"
#include "llama-graph.h"
-#include "llama-kv-cache.h"
+#include "llama-memory.h"
#include <set>
#include <vector>
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
-class llama_kv_cache_recurrent : public llama_kv_cache {
+class llama_kv_cache_recurrent : public llama_memory_i {
public:
llama_kv_cache_recurrent(
const llama_model & model,
// llama_memory_i
//
+ llama_memory_state_ptr init_batch(
+ const llama_batch & batch,
+ uint32_t n_ubatch,
+ bool embd_pooled,
+ bool logits_all) override;
+
+ llama_memory_state_ptr init_full() override;
+
+ llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
- //
- // llama_kv_cache
- //
-
- llama_memory_state_ptr init_batch(
- const llama_batch & batch,
- uint32_t n_ubatch,
- bool embd_pooled,
- bool logits_all) override;
-
- llama_memory_state_ptr init_full() override;
-
- llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
-
bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of kv cells and emplace the ubatch there
// utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
-class llama_kv_cache_unified_iswa : public llama_kv_cache {
+class llama_kv_cache_unified_iswa : public llama_memory_i {
public:
llama_kv_cache_unified_iswa(
const llama_model & model,
// llama_memory_i
//
- void clear() override;
-
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
- void seq_keep(llama_seq_id seq_id) override;
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
-
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
- //
- // llama_kv_cache
- //
-
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool get_can_shift() const override;
+ void clear() override;
+
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+ void seq_keep(llama_seq_id seq_id) override;
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
#include "llama-batch.h"
#include "llama-graph.h"
-#include "llama-kv-cache.h"
#include "llama-kv-cells.h"
+#include "llama-memory.h"
#include <unordered_map>
#include <vector>
// llama_kv_cache_unified
//
-class llama_kv_cache_unified : public llama_kv_cache {
+class llama_kv_cache_unified : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
// llama_memory_i
//
- void clear() override;
-
- bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
- void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
- void seq_keep(llama_seq_id seq_id) override;
- void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
- void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
-
- llama_pos seq_pos_min(llama_seq_id seq_id) const override;
- llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
- //
- // llama_kv_cache
- //
-
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool get_can_shift() const override;
+ void clear() override;
+
+ bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+ void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+ void seq_keep(llama_seq_id seq_id) override;
+ void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+ void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+ llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+ llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+++ /dev/null
-#include "llama-kv-cache.h"
+++ /dev/null
-#pragma once
-
-#include "llama.h"
-#include "llama-memory.h"
-
-class llama_io_write_i;
-class llama_io_read_i;
-
-struct llama_kv_cache : public llama_memory_i {
- virtual ~llama_kv_cache() = default;
-
- // TODO: move the init_ interfaces to llama_memory_i
-
- // split the input batch into a set of ubatches and verify that they can fit into the cache
- // return a state object containing the ubatches and KV cache state required to process them
- // check the llama_memory_state_i::get_status() for the result
- virtual llama_memory_state_ptr init_batch(
- const llama_batch & batch,
- uint32_t n_ubatch,
- bool embd_pooled,
- bool logits_all) = 0;
-
- // simulate full cache, used for allocating worst-case compute buffers
- virtual llama_memory_state_ptr init_full() = 0;
-
- // prepare for any pending memory updates, such as shifts, defrags, etc.
- // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
- virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
-
- // getters
- virtual bool get_can_shift() const = 0;
-
- bool get_can_edit() const override { return get_can_shift(); }
-
- //
- // state write/read
- //
-
- virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
- virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
-};
struct llama_ubatch;
+class llama_io_write_i;
+class llama_io_read_i;
+
struct llama_memory_params {
// kv cache
ggml_type type_k;
bool swa_full;
};
-// general concept of LLM memory
-// the KV cache is a type of LLM memory, but there can be other types
-class llama_memory_i {
-public:
- virtual ~llama_memory_i() = default;
-
- virtual void clear() = 0;
-
- virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
- virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
- virtual void seq_keep(llama_seq_id seq_id) = 0;
- virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
- virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
-
- virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
- virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
-
- virtual bool get_can_edit() const = 0;
-};
-
-using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
enum llama_memory_status {
LLAMA_MEMORY_STATUS_SUCCESS = 0,
LLAMA_MEMORY_STATUS_NO_UPDATE,
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
//
// TODO: rename to llama_memory_context_i ?
-class llama_memory_state_i {
-public:
+struct llama_memory_state_i {
virtual ~llama_memory_state_i() = default;
// consume the current ubatch from the state and proceed to the next one
};
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
+
+// general concept of LLM memory
+// the KV cache is a type of LLM memory, but there can be other types
+struct llama_memory_i {
+ virtual ~llama_memory_i() = default;
+
+ // split the input batch into a set of ubatches and verify that they can fit into the cache
+ // return a state object containing the ubatches and KV cache state required to process them
+ // check the llama_memory_state_i::get_status() for the result
+ virtual llama_memory_state_ptr init_batch(
+ const llama_batch & batch,
+ uint32_t n_ubatch,
+ bool embd_pooled,
+ bool logits_all) = 0;
+
+ // simulate full cache, used for allocating worst-case compute buffers
+ virtual llama_memory_state_ptr init_full() = 0;
+
+ // prepare for any pending memory updates, such as shifts, defrags, etc.
+ // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+ virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
+
+ // getters
+ virtual bool get_can_shift() const = 0;
+
+ //
+ // ops
+ //
+
+ virtual void clear() = 0;
+
+ virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
+ virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+ virtual void seq_keep(llama_seq_id seq_id) = 0;
+ virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
+ virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
+
+ virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
+ virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
+
+ //
+ // state write/read
+ //
+
+ virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
+ virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
+};
+
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
+// TODO: temporary until the llama_kv_cache is removed from the public API
+struct llama_kv_cache : public llama_memory_i {
+ virtual ~llama_kv_cache() = default;
+};