}
}
+bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
+ const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
+
+ this->mctx = mctx;
+
+ bool res = true;
+
+ res &= s_copy->ne[0] == mctx->get_n_rs();
+
+ res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
+ res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
+
+ res &= head == mctx->get_head();
+ res &= rs_z == mctx->get_rs_z();
+
+ return res;
+}
+
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
- inp_attn->set_input(ubatch);
- inp_rs->set_input(ubatch);
+ mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+ mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+ mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+
+ const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+ if (inp_rs->s_copy) {
+ GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+ int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+ // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+ for (uint32_t i = 0; i < n_rs; ++i) {
+ data[i] = mctx->get_recr()->s_copy(i);
+ }
+ }
+}
+
+bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
+ const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
+
+ this->mctx = mctx;
+
+ bool res = true;
+
+ res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+ //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+ res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
+ res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+ res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+ res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
+ res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+ res &= inp_rs->head == mctx->get_recr()->get_head();
+ res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+ return res;
}
//
inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+ inp->head = mctx_cur->get_head();
+ inp->rs_z = mctx_cur->get_rs_z();
+
return inp;
}
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
- auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr());
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
void set_input(const llama_ubatch * ubatch) override;
+ bool can_reuse(const llm_graph_params & params) override;
+
ggml_tensor * s_copy; // I32 [n_rs]
// views of s_copy, computed once per graph
ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
const llama_memory_recurrent_context * mctx;
+
+ // used in view offsets, need to match for valid graph reuse
+ uint32_t head;
+ int32_t rs_z;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
+ const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
- std::unique_ptr<llm_graph_input_rs> inp_rs,
- const llama_memory_hybrid_context * mctx) :
+ std::unique_ptr<llm_graph_input_rs> inp_rs,
+ const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)),
+ cparams(cparams),
mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override;
+ bool can_reuse(const llm_graph_params & params) override;
+
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+ const llama_cparams cparams;
+
const llama_memory_hybrid_context * mctx;
};