]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : rotate activations for better quantization (#21038)
authorGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:58:01 +0000 (16:58 +0300)
committerGitHub <redacted>
Wed, 1 Apr 2026 13:58:01 +0000 (16:58 +0300)
* llama : rotate activations for better quantization

* cont : rotate V more + refactor

* cont : rotate caches separately + support non-power-of-2 head sizes

* cont : simplify

* cont : add reference for V rotation

* cont : refactor

* cont : support context shift

* cont : consolidate

* cont : dedup + allow different types for the rotation matrix

* cont : add env variable to disable rotation

* cont : simplify attn rot kv cache logic + rename env

* cont : pre-compute the Hadamard matrices

src/llama-graph.cpp
src/llama-graph.h
src/llama-kv-cache.cpp
src/llama-kv-cache.h

index c2833b75ced784f9ff120d9f5adf0c93f61bc9d7..0e7d96ca10d457c3411cfa3b2f4c4a28a91b6878 100644 (file)
@@ -19,7 +19,7 @@
 
 // dedup helpers
 
-static ggml_tensor * build_kq_mask(
+static ggml_tensor * build_attn_inp_kq_mask(
         ggml_context * ctx,
         const llama_kv_cache_context * mctx,
         const llama_ubatch & ubatch,
@@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
     const auto n_tokens = ubatch.n_tokens;
     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
-    return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+    ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+    ggml_set_input(res);
+    ggml_set_name(res, "attn_inp_kq_mask");
+
+    return res;
 }
 
 static bool can_reuse_kq_mask(
@@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(
 
 // impl
 
+static ggml_tensor * ggml_mul_mat_aux(
+        ggml_context * ctx,
+        ggml_tensor * cur,
+        ggml_tensor * rot) {
+    const auto n = rot->ne[0];
+
+    ggml_tensor * res;
+
+    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+    res = ggml_mul_mat   (ctx, rot, res);
+    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+
+    return res;
+}
+
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_v_idxs(self_v_idxs, ubatch);
 
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    if (self_k_rot) {
+        mctx->set_input_k_rot(self_k_rot);
+    }
+
+    if (self_v_rot) {
+        mctx->set_input_v_rot(self_v_rot);
+    }
 }
 
 bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
@@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
 
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+
+    if (self_k_rot) {
+        mctx->get_base()->set_input_k_rot(self_k_rot);
+    }
+
+    if (self_v_rot) {
+        mctx->get_base()->set_input_v_rot(self_v_rot);
+    }
 }
 
 bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 
     mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
 
+    if (inp_attn->self_k_rot) {
+        mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
+    }
+
+    if (inp_attn->self_v_rot) {
+        mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
+    }
+
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (inp_rs->s_copy) {
@@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
         attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 
+    if (inp_attn->self_k_rot) {
+        attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
+    }
+
+    if (inp_attn->self_v_rot) {
+        attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
+    }
+
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (inp_rs->s_copy) {
@@ -2002,13 +2053,13 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
-
-        ggml_set_input(inp->self_kq_mask);
-
+        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
+    inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
+
     return inp;
 }
 
@@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
             int       il) const {
     GGML_ASSERT(v_mla == nullptr);
 
+    if (inp->self_k_rot) {
+        q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+        k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+    }
+
+    if (inp->self_v_rot) {
+        v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+    }
+
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     // expand k later to enable rope fusion which directly writes into k-v cache
@@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
+    if (inp->self_v_rot) {
+        cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+    }
+
     if (wo) {
         cur = build_lora_mm(wo, cur);
         if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
@@ -2090,9 +2154,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
-        ggml_set_input(inp->self_kq_mask);
-
+        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
@@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
+    if (inp->self_k_rot) {
+        q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+        if (k_cur) {
+            k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+        }
+    }
+    if (inp->self_v_rot) {
+        if (v_cur) {
+            v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+        }
+    }
+
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
@@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
+    if (inp->self_v_rot) {
+        cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+    }
+
     if (wo) {
         cur = build_lora_mm(wo, cur);
     }
@@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
-        ggml_set_input(inp->self_kq_mask);
-        ggml_set_name(inp->self_kq_mask, "self_kq_mask");
-
+        inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-        ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
     }
 
     {
@@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
-        ggml_set_input(inp->self_kq_mask_swa);
-        ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
-
+        inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-        ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
     }
 
+    inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
+    inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
+
     return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
@@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
         inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
-        ggml_set_input(inp_attn->self_kq_mask);
-
+        inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
         inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
     }
 
@@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
         inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
-        ggml_set_input(inp_attn->self_kq_mask_swa);
-
+        inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
         inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
     }
 
index 4855685ef7144a3c1f82cc84ca34932fe3d8ed71..bb0ad75198feffbb7f0639215cdf4e6c2e1d2269 100644 (file)
@@ -308,6 +308,10 @@ public:
     ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
     ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
+    // note: assumes v_rot^ == I
+    ggml_tensor * self_k_rot = nullptr;
+    ggml_tensor * self_v_rot = nullptr;
+
     // note: these have to be copies because in order to be able to reuse a graph, its inputs
     //       need to carry these parameters with them. otherwise, they can point to freed
     //       llm_graph_params from a previous batch, causing stack-use-after-return
@@ -384,6 +388,10 @@ public:
     ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
     ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
 
+    // note: using same rotation matrices for both base and swa cache
+    ggml_tensor * self_k_rot = nullptr;
+    ggml_tensor * self_v_rot = nullptr;
+
     const llama_hparams hparams;
     const llama_cparams cparams;
 
index 5f57ba9e1d8fb485b13275ff9f8f0201528ebeff..3e0fd3107f307c43b47e019094b68942745eed99 100644 (file)
 #include <map>
 #include <stdexcept>
 
+static bool ggml_is_power_of_2(int n) {
+    return (n & (n - 1)) == 0;
+}
+
+// orthonormal Walsh-Hadamard rotation matrix
+// note: res^2 == I
+static void ggml_gen_hadamard(ggml_tensor * tensor) {
+    assert(tensor->type == GGML_TYPE_F32);
+
+    const int n = tensor->ne[0];
+
+    assert(ggml_is_power_of_2(n));
+    assert(tensor->ne[1] == n);
+    assert(tensor->ne[2] == 1);
+    assert(tensor->ne[3] == 1);
+
+    std::vector<float> data_f32;
+
+    float * data = (float *) tensor->data;
+
+    if (tensor->type != GGML_TYPE_F32) {
+        data_f32.resize(n*n);
+        data = data_f32.data();
+    }
+
+    data[0*n + 0] = 1.0 / sqrtf(n);
+
+    for (int s = 1; s < n; s *= 2) {
+        for (int i = 0; i < s; i++) {
+            for (int j = 0; j < s; j++) {
+                const float val = data[i*n + j];
+
+                data[(i + s)*n + (j    )] =  val;
+                data[(i    )*n + (j + s)] =  val;
+                data[(i + s)*n + (j + s)] = -val;
+            }
+        }
+    }
+
+    if (tensor->type != GGML_TYPE_F32) {
+        ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr);
+    }
+}
+
+static ggml_tensor * ggml_mul_mat_aux(
+        ggml_context * ctx,
+        ggml_tensor * cur,
+        ggml_tensor * rot) {
+    const auto n = rot->ne[0];
+
+    ggml_tensor * res;
+
+    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+    res = ggml_mul_mat   (ctx, rot, res);
+    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+
+    return res;
+}
+
 //
 // llama_kv_cache
 //
@@ -209,6 +268,48 @@ llama_kv_cache::llama_kv_cache(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
 
+    const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
+    const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
+    if (attn_rot_disable) {
+        LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
+    }
+
+    attn_rot_k =
+        !attn_rot_disable &&
+        ggml_is_quantized(type_k) &&
+        !hparams.is_n_embd_k_gqa_variable() &&
+        hparams.n_embd_head_k() % 64 == 0;
+
+    attn_rot_v =
+        !attn_rot_disable &&
+        ggml_is_quantized(type_v) &&
+        !hparams.is_n_embd_v_gqa_variable() &&
+        hparams.n_embd_head_v() % 64 == 0;
+
+    LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
+    LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
+
+    // pre-compute the haramard matrices and keep them in host memory
+    // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
+    if (attn_rot_k || attn_rot_v) {
+        for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
+            attn_rot_hadamard[n] = std::vector<float>(n*n);
+
+            ggml_init_params params = {
+                /* .mem_size   = */ 1*ggml_tensor_overhead(),
+                /* .mem_buffer = */ nullptr,
+                /* .no_alloc   = */ true,
+            };
+
+            ggml_context_ptr ctx { ggml_init(params) };
+
+            ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n);
+            tmp->data = attn_rot_hadamard[n].data();
+
+            ggml_gen_hadamard(tmp);
+        }
+    }
+
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
 }
@@ -1004,6 +1105,14 @@ bool llama_kv_cache::get_has_shift() const {
     return result;
 }
 
+ggml_type llama_kv_cache::type_k() const {
+    return layers[0].k->type;
+}
+
+ggml_type llama_kv_cache::type_v() const {
+    return layers[0].v->type;
+}
+
 uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     uint32_t result = 0;
 
@@ -1189,6 +1298,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
     return v_idxs;
 }
 
+ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
+    ggml_tensor * res = nullptr;
+
+    if (attn_rot_k) {
+        int nrot = 64;
+
+        // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V)
+        // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
+        do {
+            nrot *= 2;
+        } while (hparams.n_embd_head_k() % nrot == 0);
+        nrot /= 2;
+
+        res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
+        ggml_set_input(res);
+        ggml_set_name(res, "attn_inp_k_rot");
+    }
+
+    return res;
+}
+
+ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const {
+    ggml_tensor * res = nullptr;
+
+    if (attn_rot_v) {
+        int nrot = 64;
+        // using smaller rotation matrices for V seems beneficial
+        // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570
+        //do {
+        //    nrot *= 2;
+        //} while (hparams.n_embd_head_v() % nrot == 0);
+        //nrot /= 2;
+
+        res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
+        ggml_set_input(res);
+        ggml_set_name(res, "attn_inp_v_rot");
+    }
+
+    return res;
+}
+
 void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     const uint32_t n_tokens = ubatch->n_tokens;
     GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
@@ -1507,6 +1657,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
     }
 }
 
+void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    const auto n_rot = dst->ne[0];
+    GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));
+
+    memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
+}
+
+void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    const auto n_rot = dst->ne[0];
+    GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));
+
+    memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
+}
+
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
@@ -1542,6 +1710,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
                ggml_context * ctx,
                 ggml_tensor * cur,
                 ggml_tensor * shift,
+                ggml_tensor * rot,
                 ggml_tensor * factors,
                       float   freq_base,
                       float   freq_scale,
@@ -1567,10 +1736,16 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
         // dequantize to f32 -> RoPE -> quantize back
         tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
 
+        // rotate back
+        tmp = ggml_mul_mat_aux(ctx, tmp, rot);
+
         tmp = ggml_rope_ext(ctx, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
+        // rotate fwd
+        tmp = ggml_mul_mat_aux(ctx, tmp, rot);
+
         tmp = ggml_cpy(ctx, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
@@ -1591,6 +1766,9 @@ public:
 
     ggml_tensor * k_shift; // I32 [kv_size*n_stream]
 
+    // note: assumes k_rot^2 == I
+    ggml_tensor * k_rot = nullptr;
+
     const llama_kv_cache * kv_self;
 };
 
@@ -1600,6 +1778,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     if (k_shift) {
         kv_self->set_input_k_shift(k_shift);
     }
+
+    if (k_rot) {
+        kv_self->set_input_k_rot(k_rot);
+    }
 }
 
 ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
@@ -1611,6 +1793,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
     ggml_set_input(inp->k_shift);
 
+    inp->k_rot = build_input_k_rot(ctx);
+
     const auto & cparams = lctx->get_cparams();
 
     for (const auto & layer : layers) {
@@ -1635,7 +1819,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 ggml_row_size(layer.k->type, n_embd_nope));
 
-        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il);
 
         ggml_build_forward_expand(gf, cur);
     }
@@ -2239,6 +2423,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
     return n_kv;
 }
 
+ggml_type llama_kv_cache_context::type_k() const {
+    return kv->type_k();
+}
+
+ggml_type llama_kv_cache_context::type_v() const {
+    return kv->type_v();
+}
+
 ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
@@ -2263,6 +2455,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con
     return kv->build_input_v_idxs(ctx, ubatch);
 }
 
+ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const {
+    return kv->build_input_k_rot(ctx);
+}
+
+ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const {
+    return kv->build_input_v_rot(ctx);
+}
+
 void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
@@ -2282,3 +2482,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
+
+void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const {
+    kv->set_input_k_rot(dst);
+}
+
+void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const {
+    kv->set_input_v_rot(dst);
+}
index 90a0610c49dd7b1c797fb2b05953b4bbd076f40f..d4569a06f71aa3167e2fdfa94d3a5c043485d5b0 100644 (file)
@@ -152,6 +152,9 @@ public:
 
     bool get_has_shift() const;
 
+    ggml_type type_k() const;
+    ggml_type type_v() const;
+
     //
     // graph_build API
     //
@@ -191,6 +194,9 @@ public:
     ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
     ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
 
+    ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
+    ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
+
     void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
     void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
 
@@ -199,6 +205,9 @@ public:
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    void set_input_k_rot(ggml_tensor * dst) const;
+    void set_input_v_rot(ggml_tensor * dst) const;
+
 private:
     const llama_model & model;
     const llama_hparams & hparams;
@@ -226,6 +235,13 @@ private:
     // SWA
     const uint32_t n_swa = 0;
 
+    // env: LLAMA_ATTN_ROT_DISABLE
+    bool attn_rot_k = false;
+    bool attn_rot_v = false;
+
+    // pre-computed hadamard martrices
+    std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
+
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
@@ -262,6 +278,7 @@ private:
                    ggml_context * ctx,
                     ggml_tensor * cur,
                     ggml_tensor * shift,
+                    ggml_tensor * rot,
                     ggml_tensor * factors,
                           float   freq_base,
                           float   freq_scale,
@@ -328,6 +345,9 @@ public:
 
     uint32_t get_n_kv() const;
 
+    ggml_type type_k() const;
+    ggml_type type_v() const;
+
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@@ -347,6 +367,9 @@ public:
     ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
     ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
 
+    ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
+    ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
+
     void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
     void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
@@ -354,6 +377,9 @@ public:
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    void set_input_k_rot(ggml_tensor * dst) const;
+    void set_input_v_rot(ggml_tensor * dst) const;
+
 private:
     llama_memory_status status;