]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : context checkpointing for hybrid and recurrent models (#16382)
authorddh0 <redacted>
Fri, 3 Oct 2025 18:34:51 +0000 (13:34 -0500)
committerGitHub <redacted>
Fri, 3 Oct 2025 18:34:51 +0000 (21:34 +0300)
* initial commit for branch 3

* generalize `swa_checkpoint` to `ctx_checkpoint`

this extends `llama-server`'s SWA checkpointing logic to include
hybrid/recurrent models such as Jamba, Granite

* oops

* disable debug prints

* keep backwards compat with `--swa-checkpoints`

Co-authored-by: Georgi Gerganov <redacted>
* update prompt re-processing message

* fix off-by-one error per GG

* keep `seq_rm` log per GG

Co-authored-by: Georgi Gerganov <redacted>
* server : fix checkpoint logic to support recurrent caches

* server : cleanup and fixes

---------

Co-authored-by: Georgi Gerganov <redacted>
common/arg.cpp
common/common.h
include/llama.h
src/llama-kv-cache-iswa.cpp
src/llama-memory-hybrid.cpp
src/llama-memory-recurrent.cpp
src/llama-model.cpp
tools/server/server.cpp

index cbca8b5ac5abb6f8962d6a8766d92f6736ff1e2d..577048c201b7692bdb9586c188038da2509bafa9 100644 (file)
@@ -1932,13 +1932,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
     add_opt(common_arg(
-        {"--swa-checkpoints"}, "N",
-        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
-            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        {"--ctx-checkpoints", "--swa-checkpoints"}, "N",
+        string_format("max number of context checkpoints to create per slot (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
         [](common_params & params, int value) {
-            params.n_swa_checkpoints = value;
+            params.n_ctx_checkpoints = value;
         }
-    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
index 40c6847f32ddbb201500367da0f1feb68cbf091c..d33788bd100b236695130c94c88cb479fa7fbfc0 100644 (file)
@@ -424,7 +424,7 @@ struct common_params {
     int32_t timeout_write     = timeout_read; // http write timeout in seconds
     int32_t n_threads_http    = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse     = 0;            // min chunk size to reuse from the cache via KV shifting
-    int32_t n_swa_checkpoints = 3;            // max number of SWA checkpoints per slot
+    int32_t n_ctx_checkpoints = 3;            // max number of context checkpoints per slot
 
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";                                                                         // NOLINT
index 452d9ec5bf285425b8935bf209a6262dbd0cbade..8fc3d7db5a917424b8765061771d9adcec3431d7 100644 (file)
@@ -543,6 +543,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is hybrid (like Jamba, Granite, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
@@ -791,8 +794,12 @@ extern "C" {
                           size_t   n_token_capacity,
                           size_t * n_token_count_out);
 
+// for backwards-compat
 #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
 
+// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
+#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
+
     typedef uint32_t llama_state_seq_flags;
 
     LLAMA_API size_t llama_state_seq_get_size_ext(
index 827302e6d25bd486562af65c345b21cbeae6ff5a..facba1d004012b503ebe22f35f022a4d5dfe1a8f 100644 (file)
@@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const {
 }
 
 void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }
 
@@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id
 }
 
 void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }
 
index abf652483c202971a9e3a36d0d200313bf59bc49..cb8832a353b11c4a4841e44cf823e0e952e60432 100644 (file)
@@ -175,17 +175,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
 }
 
 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_write(io, seq_id);
-    mem_recr->state_write(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_write(io, seq_id, flags);
+    }
+    mem_recr->state_write(io, seq_id, flags);
 }
 
 void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_read(io, seq_id);
-    mem_recr->state_read(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_read(io, seq_id, flags);
+    }
+    mem_recr->state_read(io, seq_id, flags);
 }
 
 llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
index 44645fcdd2d4824edd704e02c57d716db84cbd4d..e23e74982b2786b909b7b4afb5b1bd920913d3d5 100644 (file)
@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
 }
 
 bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
     uint32_t new_head = size;
 
     if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
             // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+            if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }
             // invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     } else {
         // seq_id is negative, then the range should include everything or nothing
         if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+            //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
             return false;
         }
     }
index cce77a854bb2f92c61e8abfad98877e9e9e0d1ce..4c2d481a41d42247e3b8e2bcef112144cfec538c 100644 (file)
@@ -20151,6 +20151,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
     return llm_arch_is_recurrent(model->arch);
 }
 
+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
+}
+
 bool llama_model_is_diffusion(const llama_model * model) {
     return llm_arch_is_diffusion(model->arch);
 }
index 6062904a8c7c00df5ab5133348a3465dfe1247b0..a21147613db0021898ecac4ff2ceb04396b15f62 100644 (file)
@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };
 
-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
 
@@ -1460,7 +1460,7 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;
 
     bool has_next_token = true;
     bool has_new_line   = false;
@@ -3541,7 +3541,11 @@ struct server_context {
                                 slot.n_past = 0;
                             }
 
-                            const auto n_swa = llama_model_n_swa(model);
+                            // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                            const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+                            // the largest pos_min required for a checkpoint to be useful
+                            const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
 
                             if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,66 +3554,62 @@ struct server_context {
                                     GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                                 }
 
-                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
                                 if (pos_min > pos_min_thold) {
                                     SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
 
-                                    // search for a SWA checkpoint
+                                    // search for a context checkpoint
                                     const auto it = std::find_if(
-                                        slot.swa_checkpoints.rbegin(),
-                                        slot.swa_checkpoints.rend(),
+                                        slot.ctx_checkpoints.rbegin(),
+                                        slot.ctx_checkpoints.rend(),
                                         [&](const auto & cur) {
-                                            return cur.pos_min <= pos_min_thold;
+                                            // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                            return cur.pos_min < pos_min_thold;
                                         }
                                     );
 
-                                    bool do_reset = it == slot.swa_checkpoints.rend();
+                                    bool do_reset = it == slot.ctx_checkpoints.rend();
+                                    //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
 
                                     if (!do_reset) {
-                                        // restore the checkpoint
-                                        const size_t swa_size = it->data.size();
-                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                                        // restore the context checkpoint
+                                        const size_t ctx_checkpoint_size = it->data.size();
+                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                                        if (n != swa_size) {
-                                            SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                        if (n != ctx_checkpoint_size) {
+                                            SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                             do_reset = true;
+                                            //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                         } else {
-                                            slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                                            SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                            slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                                            SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                         }
                                     }
 
                                     if (do_reset) {
-                                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                                 "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
                                         slot.n_past = 0;
-                                        slot.swa_checkpoints.clear();
                                     }
                                 }
                             }
 
-                            if (n_swa > 0) {
-                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+                            {
                                 // erase any checkpoints with pos_min > pos_min_thold
-                                for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                                    const auto & cur = slot.swa_checkpoints[i];
+                                for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                                    const auto & cur = slot.ctx_checkpoints[i];
                                     if (cur.pos_min > pos_min_thold) {
-                                        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                                        SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                                        SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                                     }
                                 }
                             }
                         }
 
+                        // [TAG_PROMPT_LOGITS]
                         if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+                            SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                             slot.n_past--;
+                            SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
                         }
 
                         slot.n_prompt_tokens_cache     = slot.n_past;
@@ -3623,9 +3623,9 @@ struct server_context {
                         }
                     }
 
-                    // keep only the common part
+                    // truncate any tokens that are beyond n_past for this slot
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                        // could not partially delete (likely using a non-Transformer model)
+                        SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
                         llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
 
                         // there is no common part left
@@ -3633,7 +3633,7 @@ struct server_context {
                         slot.n_prompt_tokens_cache = 0;
                     }
 
-                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+                    SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
 
                     // remove the non-common part from the cache
                     slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,38 @@ struct server_context {
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
 
-                    // make a checkpoint with the SWA memory
-                    // checkpoints are needed only if we are not using "--swa-full"
-                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                            {
-                                const auto & cur = slot.swa_checkpoints.back();
-
-                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                            }
-
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    const bool do_checkpoint =
+                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.front();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                         }
 
-                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                             /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                             /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /*.data    = */ std::vector<uint8_t>(swa_size),
+                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
                         });
 
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                        float size_total = 0.0f;
-                        for (const auto & checkpoint : slot.swa_checkpoints) {
-                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                        }
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                     }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots