// dedup helpers
-static ggml_tensor * build_kq_mask(
+static ggml_tensor * build_attn_inp_kq_mask(
ggml_context * ctx,
const llama_kv_cache_context * mctx,
const llama_ubatch & ubatch,
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
- return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+ ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+ ggml_set_input(res);
+ ggml_set_name(res, "attn_inp_kq_mask");
+
+ return res;
}
static bool can_reuse_kq_mask(
// impl
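+// applies the square matrix `rot` to every rot->ne[0]-sized block of `cur`:
+// `cur` is viewed as 2D [n, nelements/n], multiplied by `rot`, and reshaped back to its original dims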
+static ggml_tensor * ggml_mul_mat_aux(
+ ggml_context * ctx,
+ ggml_tensor * cur,
+ ggml_tensor * rot) {
+ const auto n = rot->ne[0];
+
+ ggml_tensor * res;
+
+ res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+ res = ggml_mul_mat (ctx, rot, res);
+ res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+
+ return res;
+}
+
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+ if (self_k_rot) {
+ mctx->set_input_k_rot(self_k_rot);
+ }
+
+ if (self_v_rot) {
+ mctx->set_input_v_rot(self_v_rot);
+ }
}
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+
+ if (self_k_rot) {
+ mctx->get_base()->set_input_k_rot(self_k_rot);
+ }
+
+ if (self_v_rot) {
+ mctx->get_base()->set_input_v_rot(self_v_rot);
+ }
}
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+ if (inp_attn->self_k_rot) {
+ mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
+ }
+
+ if (inp_attn->self_v_rot) {
+ mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
+ }
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
+ if (inp_attn->self_k_rot) {
+ attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
+ }
+
+ if (inp_attn->self_v_rot) {
+ attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
+ }
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
-
- ggml_set_input(inp->self_kq_mask);
-
+ inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
+ inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
+ inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
+
return inp;
}
int il) const {
GGML_ASSERT(v_mla == nullptr);
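+ // rotate Q and K with the same orthonormal matrix (K is stored rotated in the cache)
+ // the attention scores are unchanged since (R q)^T (R k) == q^T k for orthonormal R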
+ if (inp->self_k_rot) {
+ q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+ k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+ }
+
+ if (inp->self_v_rot) {
+ v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+ }
+
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
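+ // undo the V rotation on the attention output (assumes v_rot^2 == I)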
+ if (inp->self_v_rot) {
+ cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+ }
+
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
- ggml_set_input(inp->self_kq_mask);
-
+ inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
ggml_tensor * v_mla,
float kq_scale,
int il) const {
+ if (inp->self_k_rot) {
+ q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+ if (k_cur) {
+ k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+ }
+ }
+
+ if (inp->self_v_rot && v_cur) {
+ v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+ }
+
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
+ if (inp->self_v_rot) {
+ cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+ }
+
if (wo) {
cur = build_lora_mm(wo, cur);
}
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
- ggml_set_input(inp->self_kq_mask);
- ggml_set_name(inp->self_kq_mask, "self_kq_mask");
-
+ inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
- ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
}
{
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
- inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
- ggml_set_input(inp->self_kq_mask_swa);
- ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
-
+ inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
- ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}
+ inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
+ inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
+
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
- inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
- ggml_set_input(inp_attn->self_kq_mask);
-
+ inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
- inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
- ggml_set_input(inp_attn->self_kq_mask_swa);
-
+ inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
+ // note: assumes v_rot^2 == I
+ ggml_tensor * self_k_rot = nullptr;
+ ggml_tensor * self_v_rot = nullptr;
+
// note: these have to be copies because in order to be able to reuse a graph, its inputs
// need to carry these parameters with them. otherwise, they can point to freed
// llm_graph_params from a previous batch, causing stack-use-after-return
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
+ // note: using same rotation matrices for both base and swa cache
+ ggml_tensor * self_k_rot = nullptr;
+ ggml_tensor * self_v_rot = nullptr;
+
const llama_hparams hparams;
const llama_cparams cparams;
#include <map>
#include <stdexcept>
+static bool ggml_is_power_of_2(int64_t n) {
+ return n > 0 && (n & (n - 1)) == 0;
+}
+
+// orthonormal Walsh-Hadamard rotation matrix
+// note: res^2 == I
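+// built via the Sylvester construction: H_{2m} = [[H_m, H_m], [H_m, -H_m]], seeded with 1/sqrt(n)
+// the result is symmetric with orthonormal rows, hence res * res == I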
+static void ggml_gen_hadamard(ggml_tensor * tensor) {
+ // supports F32 tensors directly; other types are filled via ggml_quantize_chunk at the end
+ const int n = tensor->ne[0];
+
+ assert(ggml_is_power_of_2(n));
+ assert(tensor->ne[1] == n);
+ assert(tensor->ne[2] == 1);
+ assert(tensor->ne[3] == 1);
+
+ std::vector<float> data_f32;
+
+ float * data = (float *) tensor->data;
+
+ if (tensor->type != GGML_TYPE_F32) {
+ data_f32.resize(n*n);
+ data = data_f32.data();
+ }
+
+ data[0*n + 0] = 1.0 / sqrtf(n);
+
+ for (int s = 1; s < n; s *= 2) {
+ for (int i = 0; i < s; i++) {
+ for (int j = 0; j < s; j++) {
+ const float val = data[i*n + j];
+
+ data[(i + s)*n + (j )] = val;
+ data[(i )*n + (j + s)] = val;
+ data[(i + s)*n + (j + s)] = -val;
+ }
+ }
+ }
+
+ if (tensor->type != GGML_TYPE_F32) {
+ ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr);
+ }
+}
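+// example (n = 4): every entry is ±0.5 and the rows are mutually orthogonal unit vectors:
+//    0.5  0.5  0.5  0.5
+//    0.5 -0.5  0.5 -0.5
+//    0.5  0.5 -0.5 -0.5
+//    0.5 -0.5 -0.5  0.5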
+
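+// note: duplicate of the static helper added earlier in this change -
+// applies `rot` to every rot->ne[0]-sized block of `cur`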
+static ggml_tensor * ggml_mul_mat_aux(
+ ggml_context * ctx,
+ ggml_tensor * cur,
+ ggml_tensor * rot) {
+ const auto n = rot->ne[0];
+
+ ggml_tensor * res;
+
+ res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+ res = ggml_mul_mat (ctx, rot, res);
+ res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+
+ return res;
+}
+
//
// llama_kv_cache
//
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
+ const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
+ const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
+ if (attn_rot_disable) {
+ LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
+ }
+
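+ // enable the rotation only for quantized caches where all layers share the same K (resp. V) width
+ // and the head size is a multiple of 64, so that a power-of-2 Hadamard block fits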
+ attn_rot_k =
+ !attn_rot_disable &&
+ ggml_is_quantized(type_k) &&
+ !hparams.is_n_embd_k_gqa_variable() &&
+ hparams.n_embd_head_k() % 64 == 0;
+
+ attn_rot_v =
+ !attn_rot_disable &&
+ ggml_is_quantized(type_v) &&
+ !hparams.is_n_embd_v_gqa_variable() &&
+ hparams.n_embd_head_v() % 64 == 0;
+
+ LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
+ LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
+
+ // pre-compute the Hadamard matrices and keep them in host memory
+ // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
+ if (attn_rot_k || attn_rot_v) {
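+ // one matrix per power-of-2 size from 64 up to the largest head dimension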
+ for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
+ attn_rot_hadamard[n] = std::vector<float>(n*n);
+
+ ggml_init_params params = {
+ /* .mem_size = */ 1*ggml_tensor_overhead(),
+ /* .mem_buffer = */ nullptr,
+ /* .no_alloc = */ true,
+ };
+
+ ggml_context_ptr ctx { ggml_init(params) };
+
+ ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n);
+ tmp->data = attn_rot_hadamard[n].data();
+
+ ggml_gen_hadamard(tmp);
+ }
+ }
+
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
}
return result;
}
+ggml_type llama_kv_cache::type_k() const {
+ return layers[0].k->type;
+}
+
+ggml_type llama_kv_cache::type_v() const {
+ return layers[0].v->type;
+}
+
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
return v_idxs;
}
+ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
+ ggml_tensor * res = nullptr;
+
+ if (attn_rot_k) {
+ int nrot = 64;
+
+ // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V)
+ // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
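+ // pick the largest power-of-2 block size (>= 64) that divides the K head size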
+ do {
+ nrot *= 2;
+ } while (hparams.n_embd_head_k() % nrot == 0);
+ nrot /= 2;
+
+ res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
+ ggml_set_input(res);
+ ggml_set_name(res, "attn_inp_k_rot");
+ }
+
+ return res;
+}
+
+ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const {
+ ggml_tensor * res = nullptr;
+
+ if (attn_rot_v) {
+ int nrot = 64;
+ // using smaller rotation matrices for V seems beneficial
+ // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570
+ //do {
+ // nrot *= 2;
+ //} while (hparams.n_embd_head_v() % nrot == 0);
+ //nrot /= 2;
+
+ res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
+ ggml_set_input(res);
+ ggml_set_name(res, "attn_inp_v_rot");
+ }
+
+ return res;
+}
+
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
}
}
+void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const {
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+ const auto n_rot = dst->ne[0];
+ GGML_ASSERT(attn_rot_hadamard.count(n_rot));
+
+ memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
+}
+
+void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const {
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+ const auto n_rot = dst->ne[0];
+ GGML_ASSERT(attn_rot_hadamard.count(n_rot));
+
+ memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
+}
+
size_t llama_kv_cache::total_size() const {
size_t size = 0;
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
+ ggml_tensor * rot,
ggml_tensor * factors,
float freq_base,
float freq_scale,
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
+ // the cached K is stored rotated - undo the rotation before applying RoPE (assumes rot^2 == I)
+ if (rot) {
+ tmp = ggml_mul_mat_aux(ctx, tmp, rot);
+ }
+
tmp = ggml_rope_ext(ctx, tmp,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+ // re-apply the rotation before quantizing back into the cache
+ if (rot) {
+ tmp = ggml_mul_mat_aux(ctx, tmp, rot);
+ }
+
tmp = ggml_cpy(ctx, tmp, cur);
} else {
// we rotate only the first n_rot dimensions
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
+ // note: assumes k_rot^2 == I
+ ggml_tensor * k_rot = nullptr;
+
const llama_kv_cache * kv_self;
};
if (k_shift) {
kv_self->set_input_k_shift(k_shift);
}
+
+ if (k_rot) {
+ kv_self->set_input_k_rot(k_rot);
+ }
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
ggml_set_input(inp->k_shift);
+ inp->k_rot = build_input_k_rot(ctx);
+
const auto & cparams = lctx->get_cparams();
for (const auto & layer : layers) {
ggml_row_size(layer.k->type, n_embd_k_gqa),
ggml_row_size(layer.k->type, n_embd_nope));
- ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
+ ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il);
ggml_build_forward_expand(gf, cur);
}
return n_kv;
}
+ggml_type llama_kv_cache_context::type_k() const {
+ return kv->type_k();
+}
+
+ggml_type llama_kv_cache_context::type_v() const {
+ return kv->type_v();
+}
+
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}
return kv->build_input_v_idxs(ctx, ubatch);
}
+ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const {
+ return kv->build_input_k_rot(ctx);
+}
+
+ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const {
+ return kv->build_input_v_rot(ctx);
+}
+
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}
+
+void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const {
+ kv->set_input_k_rot(dst);
+}
+
+void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const {
+ kv->set_input_v_rot(dst);
+}
bool get_has_shift() const;
+ ggml_type type_k() const;
+ ggml_type type_v() const;
+
//
// graph_build API
//
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+ ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
+ ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
+
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+ void set_input_k_rot(ggml_tensor * dst) const;
+ void set_input_v_rot(ggml_tensor * dst) const;
+
private:
const llama_model & model;
const llama_hparams & hparams;
// SWA
const uint32_t n_swa = 0;
+ // env: LLAMA_ATTN_ROT_DISABLE
+ bool attn_rot_k = false;
+ bool attn_rot_v = false;
+
+ // pre-computed Hadamard matrices, keyed by block size (64, 128, ...)
+ std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
+
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
+ ggml_tensor * rot,
ggml_tensor * factors,
float freq_base,
float freq_scale,
uint32_t get_n_kv() const;
+ ggml_type type_k() const;
+ ggml_type type_v() const;
+
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+ ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
+ ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
+
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+ void set_input_k_rot(ggml_tensor * dst) const;
+ void set_input_v_rot(ggml_tensor * dst) const;
+
private:
llama_memory_status status;