git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
memory : handle kv_unified for hybrid models (#15050)
author     compilade <redacted>
           Sun, 3 Aug 2025 19:43:07 +0000 (15:43 -0400)
committer  GitHub <redacted>
           Sun, 3 Aug 2025 19:43:07 +0000 (21:43 +0200)
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-model.cpp

diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index d8e2086c87514f34116709cc0f87682457a1f587..e98b4e354695987d712ded4ce6974186800fad90 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -25,6 +25,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                          /* common */
              uint32_t    n_seq_max,
                  bool    offload,
+                 bool    unified,
                          /* layer filters */
       layer_filter_cb && filter_attn,
       layer_filter_cb && filter_recr) :
@@ -38,7 +39,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         type_v,
         v_trans,
         offload,
-        1,
+        unified,
         kv_size,
         n_seq_max,
         n_pad,
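
The change above threads the flag through instead of pinning it: previously the hybrid memory always constructed its attention-side KV cache with the unified flag hardcoded to 1, so `cparams.kv_unified` was silently ignored for hybrid models. A minimal sketch of the forwarding pattern, using hypothetical stand-in types rather than the actual llama.cpp classes:

    #include <cstdint>

    // Hypothetical stand-ins for llama_kv_cache_unified / llama_memory_hybrid,
    // for illustration only.
    struct kv_cache_unified {
        kv_cache_unified(bool unified, uint32_t kv_size)
            : unified(unified), kv_size(kv_size) {}
        bool     unified;
        uint32_t kv_size;
    };

    struct memory_hybrid {
        // Before this commit the equivalent of `unified` was hardcoded to 1
        // here; now the caller decides and the value is simply forwarded.
        memory_hybrid(bool unified, uint32_t kv_size)
            : mem_attn(unified, kv_size) {}
        kv_cache_unified mem_attn;
    };
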
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
index 4ac318175785e50d410b32addf15e2674ef3a39b..c2d56cd541594381f40524f18ab72263481aef05 100644
--- a/src/llama-memory-hybrid.h
+++ b/src/llama-memory-hybrid.h
@@ -39,6 +39,7 @@ public:
                              /* common */
                  uint32_t    n_seq_max,
                      bool    offload,
+                     bool    unified,
                              /* layer filters */
           layer_filter_cb && filter_attn = nullptr,
           layer_filter_cb && filter_recr = nullptr);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 6b58fb8a059f490545dd23ca32b304bf16dd52a1..60a615c159a51c6a96ec3baa2deeefcf50d89a3b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -17598,6 +17598,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
                         /* n_seq_max         */ cparams.n_seq_max,
                         /* offload           */ cparams.offload_kqv,
+                        /* unified           */ cparams.kv_unified,
                         /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
                         /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
                 } else {
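
With the flag wired through at the call site, a hybrid (attention + recurrent) model's context should now honor the same unified-KV setting as pure attention models. A hedged usage sketch, assuming the public llama_context_params exposes a kv_unified field (as recent llama.cpp headers do) and using a placeholder model path:

    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        // "model.gguf" is a placeholder path for a hybrid model.
        llama_model * model = llama_model_load_from_file("model.gguf", mparams);
        if (model == NULL) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        // After this commit, the flag is forwarded into llama_memory_hybrid
        // instead of being overridden with a hardcoded value.
        cparams.kv_unified = true;

        llama_context * ctx = llama_init_from_model(model, cparams);
        // ... use ctx ...
        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }
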