]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : relax SWA masking condition (#14119)
authorGeorgi Gerganov <redacted>
Wed, 11 Jun 2025 13:48:45 +0000 (16:48 +0300)
committerGitHub <redacted>
Wed, 11 Jun 2025 13:48:45 +0000 (16:48 +0300)
ggml-ci

src/llama-kv-cache-unified.cpp

index b184735566a0a3cdbaba83c26ba05a1051b2150b..1a9f4e3159f94d83b03b18dd3393c875309aec97 100644 (file)
@@ -582,21 +582,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }
 
-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            //const llama_pos    pos    = ubatch.pos[i];
+            //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
             // can we use this cell? either:
             //  - the cell is empty
             //  - the cell is occupied only by one sequence:
-            //    - mask causally, if the sequence is the same as the one we are inserting
+            //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //    - mask SWA, using current max pos for that sequence in the cache
             //                always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);
@@ -604,21 +598,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);
 
-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}
 
                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
 
                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
@@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
             cells.rm(head_cur + i);
         }
 
@@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         }
     }
 
+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }