]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : fix SWA checks + disable cacheless iSWA (#15811)
authorGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 07:39:22 +0000 (10:39 +0300)
committerGitHub <redacted>
Fri, 5 Sep 2025 07:39:22 +0000 (10:39 +0300)
ggml-ci

src/llama-graph.cpp
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-kv-cache-iswa.cpp
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-model.cpp

index 7ce2960eb311b22c9b06f0d7678b507868aef1ef..4abb6008dd18490af90751fbaa4a3145c1b6ace0 100644 (file)
@@ -297,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
     float * data = (float *) kq_mask->data;
 
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+
     for (int h = 0; h < 1; ++h) {
         for (int i1 = 0; i1 < n_tokens; ++i1) {
             const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -315,9 +318,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                         continue; // skip future tokens for causal attention
                     }
 
-                    if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                        continue; // skip masked tokens for SWA
-                    }
+                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                    //    continue; // skip masked tokens for SWA
+                    //}
 
                     // TODO: reimplement this like in llama_kv_cache_unified
                     if (hparams.use_alibi) {
index 4b7a65362b9f04d17d0300789e69902a8c302a1c..c04ac58f1af4ba3746c97ddefcfd723c390ad027 100644 (file)
@@ -180,7 +180,7 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
     assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {
index 06d1e51db055234e7a5ada4bfa2a2f8fce4a084f..89f5c7ab65dce2ddf2c69ef871181a421a5efbc7 100644 (file)
@@ -229,7 +229,10 @@ struct llama_hparams {
     // number of layers for which has_kv() returns true
     uint32_t n_layer_kv() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+    // note that this function uses different SWA parameters from those in the hparams
+    // TODO: think of a better place for this function
+    // TODO: pack the SWA params in a struct?
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
index 7c51a1981e05e66688ef0a6edd3afcc237802fca..d7342914c6b7cbf17eeb8cae258b12413e16d6e7 100644 (file)
@@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     kv_base = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, filter_base, reuse);
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, filter_swa, reuse);
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {
index 1564faedf4eb48658ca6c9ee5c979a01513d9efd..ae35f74201e9cccd51a04102acf5cfac43929d6c 100644 (file)
@@ -27,10 +27,11 @@ llama_kv_cache::llama_kv_cache(
                  uint32_t   n_seq_max,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
+           llama_swa_type   swa_type,
     const layer_filter_cb & filter,
     const  layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -1392,7 +1393,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 }
 
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return hparams.is_masked_swa(p0, p1);
+    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
 }
 
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
index 55ee355b22928059696b59925fd43848d5adfb3a..b545bf5b9cf7113867688c0ff797d94ccf4f321c 100644 (file)
@@ -89,6 +89,7 @@ public:
                      uint32_t   n_seq_max,
                      uint32_t   n_pad,
                      uint32_t   n_swa,
+               llama_swa_type   swa_type,
         const layer_filter_cb & filter,
         const  layer_reuse_cb & reuse);
 
@@ -211,6 +212,9 @@ private:
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // this is the SWA type of the cache - not to be confused with the model SWA type
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
index 38f92da11820d5ce1275968adb7f90433169f4f6..ba61ebaa885feffc62ac012563736f59fa4325eb 100644 (file)
@@ -17,6 +17,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                  uint32_t   kv_size,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
+           llama_swa_type   swa_type,
                             /* recurrent */
                 ggml_type   type_r,
                 ggml_type   type_s,
@@ -40,6 +41,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
+        swa_type,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,
index 0eb63f5ef86d7ee2bccb536b497e16672e7cbb7c..11a35651782974023d48be9ddba531c0c79e0a26 100644 (file)
@@ -27,6 +27,7 @@ public:
                  uint32_t   kv_size,
                  uint32_t   n_pad,
                  uint32_t   n_swa,
+           llama_swa_type   swa_type,
                             /* recurrent */
                 ggml_type   type_r,
                 ggml_type   type_s,
index afc4cb48098d21e0933813f9342de8194ac3fbf8..1813f06d7b308b7014d89ca8cd2c72851c3a3d57 100644 (file)
@@ -11084,7 +11084,8 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -18632,7 +18633,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        case LLM_ARCH_GEMMA_EMBEDDING:
+        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
             {
@@ -18681,6 +18682,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* attn_kv_size      */ cparams.n_ctx,
                         /* attn_n_pad        */ padding,
                         /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
                         /* recurrent_type_k  */ GGML_TYPE_F32,
                         /* recurrent_type_v  */ GGML_TYPE_F32,
                         /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
@@ -18750,6 +18752,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 cparams.n_seq_max,
                                 padding,
                                 hparams.n_swa,
+                                hparams.swa_type,
                                 nullptr,
                                 nullptr);
                     }