]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : support layer reuse (#15504)
authorGeorgi Gerganov <redacted>
Sun, 24 Aug 2025 10:07:07 +0000 (13:07 +0300)
committerGitHub <redacted>
Sun, 24 Aug 2025 10:07:07 +0000 (13:07 +0300)
* kv-cache : support layer reuse

ggml-ci

* cont : update comments [no ci]

12 files changed:
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-kv-cache-iswa.cpp
src/llama-kv-cache-iswa.h
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-memory-recurrent.cpp
src/llama-memory-recurrent.h
src/llama-memory.h
src/llama-model.cpp

index 7a06368dcda68e1f133fae49f59008db4fb8af07..91636572da8b29d61722c59c301806cd280397d0 100644 (file)
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::has_kv(uint32_t il) const {
+    if (n_layer_kv_from_start >= 0) {
+        if (il < (uint32_t) n_layer_kv_from_start) {
+            return true;
+        }
+
+        return false;
+    }
+
+    // by default, all layers have kv
+    return true;
+}
+
+uint32_t llama_hparams::n_layer_kv() const {
+    uint32_t res = 0;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (has_kv(il)) {
+            res++;
+        }
+    }
+
+    return res;
+}
index bd23122443271b7fc3d078ad16a8a19e618bba97..60415f0c202a4a039456fceaea898979dd2d396f 100644 (file)
@@ -41,6 +41,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
+     int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
     uint32_t n_pos_per_embd() const;
 
     bool is_swa(uint32_t il) const;
+
+    bool has_kv(uint32_t il) const;
+
+    // number of layers for which has_kv() returns true
+    uint32_t n_layer_kv() const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
index a11ee5a5b185d0c03e2d3f94219296942095b1eb..d7342914c6b7cbf17eeb8cae258b12413e16d6e7 100644 (file)
@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
                  uint32_t   n_ubatch,
-                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
-    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+                 uint32_t   n_pad,
+    const layer_filter_cb & filter,
+    const  layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
+
+    // chain filters
+    const layer_filter_cb filter_base = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return !model.hparams.is_swa(il);
+    };
+
+    const layer_filter_cb filter_swa  = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return  model.hparams.is_swa(il);
+    };
 
     const uint32_t size_base = kv_size;
 
@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 
     kv_base = std::make_unique<llama_kv_cache>(
-            model, std::move(filter_base), type_k, type_v,
+            model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, LLAMA_SWA_TYPE_NONE);
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
-            model, std::move(filter_swa), type_k, type_v,
+            model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, hparams.swa_type);
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {
index dd673f18e7e08d847d047d14d15fa0b8f34ce1cc..5ed134b7958005ab3ccc185d93ba6c505a96f961 100644 (file)
@@ -20,11 +20,13 @@ public:
                          bool   v_trans,
                          bool   offload,
                          bool   swa_full,
-                         bool  ,
+                         bool   unified,
                      uint32_t   kv_size,
                      uint32_t   n_seq_max,
                      uint32_t   n_ubatch,
-                     uint32_t   n_pad);
+                     uint32_t   n_pad,
+        const layer_filter_cb & filter,
+        const  layer_reuse_cb & reuse);
 
     ~llama_kv_cache_iswa() = default;
 
index 70ddd5f4b952cda367ab6b18a9abee8863533556..d7ab56ccd9aacf9f7d356067ae1b0cd51bf5c878 100644 (file)
 //
 
 llama_kv_cache::llama_kv_cache(
-        const llama_model &  model,
-          layer_filter_cb && filter,
-                ggml_type    type_k,
-                ggml_type    type_v,
-                     bool    v_trans,
-                     bool    offload,
-                     bool    unified,
-                 uint32_t    kv_size,
-                 uint32_t    n_seq_max,
-                 uint32_t    n_pad,
-                 uint32_t    n_swa,
-           llama_swa_type    swa_type) :
+        const llama_model & model,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                     bool   offload,
+                     bool   unified,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max,
+                 uint32_t   n_pad,
+                 uint32_t   n_swa,
+           llama_swa_type   swa_type,
+    const layer_filter_cb & filter,
+    const  layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
     n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    auto n_layer_cache = hparams.n_layer;
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        n_layer_cache = 20;
-    }
-    if (model.arch == LLM_ARCH_GLM4_MOE) {
-        // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
-    }
+    const uint32_t n_layer_kv = hparams.n_layer_kv();
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -50,7 +43,7 @@ llama_kv_cache::llama_kv_cache(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
                 __func__, hparams.n_embd_v_gqa_max());
     }
 
-    for (uint32_t il = 0; il < n_layer_cache; il++) {
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (!hparams.has_kv(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
+            continue;
+        }
+
         if (filter && !filter(il)) {
-            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
             continue;
         }
 
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
         layers.push_back({ il, k, v, k_stream, v_stream, });
     }
 
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+    if (reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
 
-        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
-            if (filter && !filter(il)) {
-                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+        for (uint32_t il = 0; il < hparams.n_layer; il++) {
+            const int32_t il_reuse = reuse(il);
+
+            if (il_reuse < 0) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
                 continue;
             }
 
-            const bool     is_swa   = hparams.is_swa(il);
-            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
+                continue;
+            }
 
             GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+
             map_layer_ids[il] = map_layer_ids[il_reuse];
 
-            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
         }
     }
 
index 297a0973dd467a641d8a5f945a919018bdb85e15..76a5cb1e28e7e642b31ab7376fd7bc5305221424 100644 (file)
@@ -21,9 +21,6 @@ class llama_kv_cache : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());
@@ -82,18 +79,19 @@ public:
     using slot_info_vec_t = std::vector<slot_info>;
 
     llama_kv_cache(
-            const llama_model &  model,
-              layer_filter_cb && filter,
-                    ggml_type    type_k,
-                    ggml_type    type_v,
-                         bool    v_trans,
-                         bool    offload,
-                         bool    unified,
-                     uint32_t    kv_size,
-                     uint32_t    n_seq_max,
-                     uint32_t    n_pad,
-                     uint32_t    n_swa,
-               llama_swa_type    swa_type);
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
+                         bool   unified,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max,
+                     uint32_t   n_pad,
+                     uint32_t   n_swa,
+               llama_swa_type   swa_type,
+        const layer_filter_cb & filter,
+        const  layer_reuse_cb & reuse);
 
     ~llama_kv_cache() = default;
 
index f8303dacbf8adf952e7a5b583f746b44cd9a58a2..ba61ebaa885feffc62ac012563736f59fa4325eb 100644 (file)
@@ -9,32 +9,29 @@
 //
 
 llama_memory_hybrid::llama_memory_hybrid(
-    const llama_model & model,
-                         /* attn */
-            ggml_type    type_k,
-            ggml_type    type_v,
-                 bool    v_trans,
-             uint32_t    kv_size,
-             uint32_t    n_pad,
-             uint32_t    n_swa,
-       llama_swa_type    swa_type,
-                         /* recurrent */
-            ggml_type    type_r,
-            ggml_type    type_s,
-             uint32_t    rs_size,
-                         /* common */
-             uint32_t    n_seq_max,
-                 bool    offload,
-                 bool    unified,
-                         /* layer filters */
-      layer_filter_cb && filter_attn,
-      layer_filter_cb && filter_recr) :
+        const llama_model & model,
+                            /* attn */
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                 uint32_t   kv_size,
+                 uint32_t   n_pad,
+                 uint32_t   n_swa,
+           llama_swa_type   swa_type,
+                            /* recurrent */
+                ggml_type   type_r,
+                ggml_type   type_s,
+                 uint32_t   rs_size,
+                            /* common */
+                 uint32_t   n_seq_max,
+                     bool   offload,
+                     bool   unified,
+                            /* layer filters */
+    const layer_filter_cb & filter_attn,
+    const layer_filter_cb & filter_recr) :
     hparams(model.hparams),
     mem_attn(new llama_kv_cache(
         model,
-        filter_attn == nullptr ?
-            [&](int32_t il) { return !hparams.is_recurrent(il); }
-            : filter_attn,
         type_k,
         type_v,
         v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
-        swa_type
+        swa_type,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
+            : filter_attn,
+        nullptr
     )),
     mem_recr(new llama_memory_recurrent(
         model,
-        filter_recr == nullptr ?
-            [&](int32_t il) { return hparams.is_recurrent(il); }
-            : filter_recr,
         type_r,
         type_s,
         offload,
         rs_size,
-        n_seq_max
+        n_seq_max,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return hparams.is_recurrent(il); }
+            : filter_recr
     )) {}
 
 llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
index e9c64ee40aae42a8ca8edaa7c6af1d826b789b3b..11a35651782974023d48be9ddba531c0c79e0a26 100644 (file)
 
 class llama_memory_hybrid : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_hybrid(
         const llama_model & model,
                             /* attn */
-                ggml_type    type_k,
-                ggml_type    type_v,
-                     bool    v_trans,
-                 uint32_t    kv_size,
-                 uint32_t    n_pad,
-                 uint32_t    n_swa,
-           llama_swa_type    swa_type,
-                             /* recurrent */
-                ggml_type    type_r,
-                ggml_type    type_s,
-                 uint32_t    rs_size,
-                             /* common */
-                 uint32_t    n_seq_max,
-                     bool    offload,
-                     bool    unified,
-                             /* layer filters */
-          layer_filter_cb && filter_attn = nullptr,
-          layer_filter_cb && filter_recr = nullptr);
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                 uint32_t   kv_size,
+                 uint32_t   n_pad,
+                 uint32_t   n_swa,
+           llama_swa_type   swa_type,
+                            /* recurrent */
+                ggml_type   type_r,
+                ggml_type   type_s,
+                 uint32_t   rs_size,
+                            /* common */
+                 uint32_t   n_seq_max,
+                     bool   offload,
+                     bool   unified,
+                            /* layer filters */
+    const layer_filter_cb & filter_attn = nullptr,
+    const layer_filter_cb & filter_recr = nullptr);
 
     ~llama_memory_hybrid() = default;
 
index 849675c418891d9da3436d5c2a6a035cb7b95ce9..08716ed91aed124fcb71c70fec4be6d2ec8e94c7 100644 (file)
 //
 
 llama_memory_recurrent::llama_memory_recurrent(
-        const llama_model &  model,
-          layer_filter_cb && filter,
-                ggml_type    type_r,
-                ggml_type    type_s,
-                     bool    offload,
-                 uint32_t    mem_size,
-                 uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model & model,
+                ggml_type   type_r,
+                ggml_type   type_s,
+                     bool   offload,
+                 uint32_t   mem_size,
+                 uint32_t   n_seq_max,
+    const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     head = 0;
index c8e8623602f78ad683d24b7c144190976b2d6287..c4daf00495bc2a43192914e4b19dbdf3079708af 100644 (file)
 //       see the implementation of llama_kv_cache_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_recurrent(
-            const llama_model &  model,
-              layer_filter_cb && filter,
-                    ggml_type    type_r,
-                    ggml_type    type_s,
-                         bool    offload,
-                     uint32_t    mem_size,
-                     uint32_t    n_seq_max);
+            const llama_model & model,
+                    ggml_type   type_r,
+                    ggml_type   type_s,
+                         bool   offload,
+                     uint32_t   mem_size,
+                     uint32_t   n_seq_max,
+        const layer_filter_cb & filter);
 
     ~llama_memory_recurrent() = default;
 
index 94d858bccc2e0393fd50767246ebd5d12e203a7f..ccd1f073b0848c28920dfcd909c753e5482f4c3a 100644 (file)
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <memory>
+#include <functional>
 
 struct llama_ubatch;
 
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
 struct llama_memory_i {
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    // this callback is used to specify which layers should reuse memory from other layers
+    // return negative value to indicate that the layer il should not reuse memory
+    using layer_reuse_cb = std::function<int32_t(int32_t il)>;
+
     virtual ~llama_memory_i() = default;
 
     // split the input batch into a set of ubatches and verify that they can fit into the cache
index d5148f7df36ed11cf98d4c8d713010cddbee2f9d..7d3429617bef98d2ccaa679351c6fb85f72517f9 100644 (file)
@@ -1115,6 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.set_swa_pattern(5);
 
+                hparams.n_layer_kv_from_start     = 20;
                 hparams.rope_freq_base_train_swa  = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;
                 hparams.f_attention_scale         = 1.0f;
@@ -1474,12 +1475,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // Expert gating function (GLM-4.5 uses sigmoid)
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
                 if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
-                    hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
+                    hparams.expert_gating_func =  LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
                 }
 
                 // NextN/MTP parameters
                 ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,        hparams.nextn_predict_layers, false);
 
+                // TODO: when MTP is implemented, this should probably be updated if needed
+                hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
                 switch (hparams.n_layer) {
                     case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
                     case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
@@ -10524,7 +10528,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     const int64_t n_embd_altup;
     const int64_t n_altup;
     const int     i_altup_act;
-    const int     n_layer_kv = 20; // number of layers having KV [KV_REUSE]
     const int     n_layer_sparsity = 10; // number of layers using activation sparsity
     const float   f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
 
@@ -10574,8 +10577,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
 
         for (int il = 0; il < n_layer; ++il) {
             // this block is made to be closely resemble Gemma3p5DecoderLayer on python code
-            const bool has_kv = (il < n_layer_kv);
-
             const float freq_base_l  = model.get_rope_freq_base (cparams, il);
             const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
@@ -10595,7 +10596,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
             ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
 
             // self-attention
-            if (has_kv) {
+            if (hparams.has_kv(il)) {
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
@@ -10635,7 +10636,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
             } else {
-                // no KV layers
+                // reuse KV cache of earlier layers
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -18256,12 +18257,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 if (llm_arch_is_recurrent(arch)) {
                     res = new llama_memory_recurrent(
                             *this,
-                            nullptr,
                             GGML_TYPE_F32,
                             GGML_TYPE_F32,
                             cparams.offload_kqv,
                             std::max((uint32_t) 1, cparams.n_seq_max),
-                            cparams.n_seq_max);
+                            cparams.n_seq_max,
+                            nullptr);
                 } else if (llm_arch_is_hybrid(arch)) {
                     const auto padding = llama_kv_cache::get_padding(cparams);
 
@@ -18302,6 +18303,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 
                     LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
+                    llama_memory_i::layer_reuse_cb reuse = nullptr;
+
+                    if (arch == LLM_ARCH_GEMMA3N) {
+                        reuse = [&](int32_t il) {
+                            if (il >= (int32_t) hparams.n_layer_kv_from_start) {
+                                return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
+                            }
+
+                            return -1;
+                        };
+                    }
+
                     if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
                         GGML_ASSERT(hparams.is_swa_any());
 
@@ -18316,13 +18329,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 cparams.n_ubatch,
-                                padding);
+                                padding,
+                                nullptr,
+                                reuse);
                     } else {
                         GGML_ASSERT(!hparams.is_swa_any());
 
                         res = new llama_kv_cache(
                                 *this,
-                                nullptr,
                                 params.type_k,
                                 params.type_v,
                                 !cparams.flash_attn,
@@ -18332,7 +18346,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 cparams.n_seq_max,
                                 padding,
                                 hparams.n_swa,
-                                hparams.swa_type);
+                                hparams.swa_type,
+                                nullptr,
+                                nullptr);
                     }
                 }
             }