}
// reserve worst-case graph
- if (!hparams.vocab_only && memory) {
+ if (!hparams.vocab_only) {
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
int n_splits_tg = -1;
int n_nodes_tg = -1;
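// (descriptive note: these record the graph split/node counts measured while reserving the token-generation (tg) worst-case graph)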
- // simulate full KV cache
-
- const auto mctx = memory->init_full();
- if (!mctx) {
- throw std::runtime_error("failed to initialize KV cache");
+ llama_memory_context_ptr mctx;
+ if (memory) {
+ LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+ mctx = memory->init_full();
+ if (!mctx) {
+ throw std::runtime_error("failed to initialize memory module");
+ }
}
cross.v_embd.clear();
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
- // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+ // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
llama_pos pos_min[LLAMA_MAX_SEQ];
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
    pos_min[s] = std::numeric_limits<llama_pos>::max();
}
// find the lowest position of the failed ubatch for each affected sequence
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
    const auto & seq_id = ubatch.seq_id[i][0];
    pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
}
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
    if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
        continue;
    }
-   LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+   LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
    memory->seq_rm(s, pos_min[s], -1);
}
}
if (memory != nullptr) {
- LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+ LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io);
}
}
if (memory) {
- LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+ LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
memory->state_read(io);
}