]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
memory: Hybrid context shift (#17009) upstream/0.0.7011
authorGabe Goodhart <redacted>
Mon, 10 Nov 2025 15:14:23 +0000 (08:14 -0700)
committerGitHub <redacted>
Mon, 10 Nov 2025 15:14:23 +0000 (17:14 +0200)
* feat(memory): Only fail partial erasure of recurrent tail

The recurrent state is always assumed to be the state as of the last update
from the final token in the sequence. When doing a partial erasure, if the
range does not include the final token, the erasure can be considered a
success since any memory used for the sequence prior to the final token
(which is no memory) has been successfully removed.

There is one potential case that this doesn't address which is the pruning
of cache to remove sensitive data from the context. This wouldn't work for
attention cache partial removal (in the middle) either since the KV state
is linearly-dependent and states in later sequence positions would still be
based on the state from the sensitive data, even if that data is no longer
cached, so I don't think this is relevant, but it is worth noting that the
semantics of this change for a partial erasure in the middle of the cache
are essentially "my context is already compressed" and not "all trace of
the removed tokens has been removed."

https://github.com/ggml-org/llama.cpp/issues/16768
Branch: HybridContextShift-16768

Signed-off-by: Gabe Goodhart <redacted>
* fix(main): Check the output of seq_rm for prefix matching

This prefix matching is explicitly attempting to remove the tokens at the
end of the sequence that don't match. This is the operation that can't be
performed on a recurrent cache due to the state being updated in place, so
if this removal fails, we need to clear the whole cache.

https://github.com/ggml-org/llama.cpp/issues/16768
Branch: HybridContextShift-16768

Signed-off-by: Gabe Goodhart <redacted>
* fix(memory): Fix condition for partial erasure failure if p0 > pos

Signed-off-by: Gabe Goodhart <redacted>
Co-authored-by: compilade <redacted>
* style: Fix extra parens

Signed-off-by: Gabe Goodhart <redacted>
Co-authored-by: Georgi Gerganov <redacted>
* fix(main.cpp): Set n_matching_session_tokens to 0 on cache clear

https://github.com/ggml-org/llama.cpp/issues/16768
Branch: HybridContextShift-16768

Signed-off-by: Gabe Goodhart <redacted>
---------

Signed-off-by: Gabe Goodhart <redacted>
Co-authored-by: compilade <redacted>
Co-authored-by: Georgi Gerganov <redacted>
src/llama-memory-recurrent.cpp
tools/main/main.cpp

index 276e1697d466c6dd738c3335265735db7e0a8331..812bf2530491a74433b41c114601aa793cffd76a 100644 (file)
@@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
+    // models like Mamba or RWKV can't have a state partially erased at the end
+    // of the sequence because their state isn't preserved for previous tokens
     if (seq_id >= (int64_t) size) {
         // could be fatal
         return false;
@@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
-            // partial intersection is invalid
-            if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+            // partial intersection is invalid if it includes the final pos
+            if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
                 //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }
index 498e00e3a5e58deb9e2622f1f6165e9ad0f1c375..33e88623357934bd09e9e32b19de3afe5440c266 100644 (file)
@@ -354,7 +354,11 @@ int main(int argc, char ** argv) {
         }
 
         // remove any "future" tokens that we might have inherited from the previous session
-        llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1);
+        if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) {
+            LOG_INF("%s: unable to resuse common prefix\n", __func__);
+            n_matching_session_tokens = 0;
+            llama_memory_seq_rm(mem, -1, -1, -1);
+        }
     }
 
     LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",