]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : improve context checkpoint logic (#16440)
authorGeorgi Gerganov <redacted>
Wed, 8 Oct 2025 07:57:29 +0000 (10:57 +0300)
committerGitHub <redacted>
Wed, 8 Oct 2025 07:57:29 +0000 (10:57 +0300)
src/llama-memory-recurrent.cpp
tools/server/server.cpp

index 9402d9cb8df9fbe46506d6e813a3949760b7ae31..d67f5a5f47b87c2440a2a20c281ac71692498e81 100644 (file)
@@ -861,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
 bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence
-
         seq_rm(dest_seq_id, -1, -1);
 
+        if (cell_count == 0) {
+            return true;
+        }
+
         llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
         llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
index 307653764cc759733c678fe6efe0c7ba34838f99..5980d43b49a4be152921ec221453eb0baedfc984 100644 (file)
@@ -3676,6 +3676,20 @@ struct server_context {
                         alora_disabled_id = enabled_loras[0];
                     }
 
+                    bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    do_checkpoint = do_checkpoint && (
+                            llama_model_is_recurrent(model) ||
+                            llama_model_is_hybrid(model) ||
+                            (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+                            );
+
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                         // get next token to process
@@ -3700,6 +3714,11 @@ struct server_context {
 
                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
+
+                        // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
+                        if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+                            break;
+                        }
                     }
 
                     // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
@@ -3730,6 +3749,39 @@ struct server_context {
                         slot.i_batch   = batch.n_tokens - 1;
 
                         SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
+                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+                        // no need for empty or small checkpoints
+                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+                        // no need to create checkpoints that are too close together
+                        do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+                        if (do_checkpoint) {
+                            while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                                // make room for the new checkpoint, if needed
+                                const auto & cur = slot.ctx_checkpoints.front();
+                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+                            }
+
+                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+                                /*.pos_min = */ pos_min,
+                                /*.pos_max = */ pos_max,
+                                /*.data    = */ std::vector<uint8_t>(checkpoint_size),
+                            });
+
+                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                            SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                        }
                     }
                 }
 
@@ -3853,40 +3905,6 @@ struct server_context {
 
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
-
-                    // make a checkpoint of the parts of the memory that cannot be rolled back.
-                    // checkpoints are created only if:
-                    // - the model uses SWA and we are not using `swa_full`
-                    // - the model architecture is marked as recurrent or hybrid
-                    //
-                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
-                    const bool do_checkpoint =
-                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
-                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
-                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
-                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                            // make room for the new checkpoint, if needed
-                            const auto & cur = slot.ctx_checkpoints.front();
-                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
-                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
-                        }
-
-                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
-                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
-                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
-                        });
-
-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
-                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
-                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }