Checkpoint every n tokens: squash (#20087)
author    Piotr Wilkin (ilintar) <redacted>
          Fri, 6 Mar 2026 10:39:26 +0000 (11:39 +0100)
committer GitHub <redacted>
          Fri, 6 Mar 2026 10:39:26 +0000 (11:39 +0100)
common/arg.cpp
common/common.h
tools/server/server-context.cpp

diff --git a/common/arg.cpp b/common/arg.cpp
index cd73d9642043f498035d690c2704722bdd75bca2..0d8561dbb3c76c489215a497cfd8219640729b00 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1279,13 +1279,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
     add_opt(common_arg(
-        {"--ctx-checkpoints", "--swa-checkpoints"}, "N",
+        {"-ctxcp", "--ctx-checkpoints", "--swa-checkpoints"}, "N",
         string_format("max number of context checkpoints to create per slot (default: %d)"
             "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
         [](common_params & params, int value) {
             params.n_ctx_checkpoints = value;
         }
     ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
+    add_opt(common_arg(
+        {"-cpent", "--checkpoint-every-n-tokens"}, "N",
+        string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt),
+        [](common_params & params, int value) {
+            params.checkpoint_every_nt = value;
+        }
+    ).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
     add_opt(common_arg(
         {"-cram", "--cache-ram"}, "N",
         string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
diff --git a/common/common.h b/common/common.h
index 3c09cdf04055c791d2b219aa80400a2a734183a0..3e1b23f5d46adc8ac7eda25c7adba5f4fc017f20 100644
--- a/common/common.h
+++ b/common/common.h
@@ -516,14 +516,15 @@ struct common_params {
     std::string cls_sep    = "\t";  // separator of classification sequences
 
     // server params
-    int32_t port              = 8080;         // server listens on this network port
-    int32_t timeout_read      = 600;          // http read timeout in seconds
-    int32_t timeout_write     = timeout_read; // http write timeout in seconds
-    int32_t n_threads_http    = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
-    int32_t n_cache_reuse     = 0;            // min chunk size to reuse from the cache via KV shifting
-    bool    cache_prompt      = true;         // whether to enable prompt caching
-    int32_t n_ctx_checkpoints = 8;            // max number of context checkpoints per slot
-    int32_t cache_ram_mib     = 8192;         // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
+    int32_t port                = 8080;          // server listens on this network port
+    int32_t timeout_read        = 600;           // http read timeout in seconds
+    int32_t timeout_write       = timeout_read;  // http write timeout in seconds
+    int32_t n_threads_http      = -1;    // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse       = 0;     // min chunk size to reuse from the cache via KV shifting
+    bool    cache_prompt        = true;  // whether to enable prompt caching
+    int32_t n_ctx_checkpoints   = 32;     // max number of context checkpoints per slot
+    int32_t checkpoint_every_nt = 8192;   // make a checkpoint every n tokens during prefill
+    int32_t cache_ram_mib       = 8192;  // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
 
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";                                                                         // NOLINT
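
Worth noting about the new defaults: 32 checkpoints taken every 8192 tokens cover up to 32 × 8192 = 262144 tokens of prefill before the oldest checkpoint has to be evicted, which is presumably why n_ctx_checkpoints was raised from 8 to 32 alongside this change.
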
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index aafed495020ccf60c05047c483eae49e8dd6688f..9dbd6d798a3d214aebbdb29f765f1a09d17b1a9b 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -12,6 +12,7 @@
 #include "mtmd.h"
 #include "mtmd-helper.h"
 
+#include <algorithm>
 #include <cstddef>
 #include <cinttypes>
 #include <memory>
@@ -2348,8 +2349,10 @@ private:
                                     const auto it = std::find_if(
                                         slot.prompt.checkpoints.rbegin(),
                                         slot.prompt.checkpoints.rend(),
-                                        [&](const auto & cur) {
+                                        [&, func_name = __func__](const auto & cur) {
                                             // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                            LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
+                                                func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
                                             return cur.pos_min < pos_min_thold;
                                         }
                                     );
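
For context, the lambda patched above is the predicate of an unchanged reverse std::find_if over the checkpoint list: it picks the newest checkpoint whose pos_min is still below the restore threshold. A minimal standalone sketch of that pattern, with simplified stand-in types (checkpoint and find_checkpoint are illustrative names, not the server's own; the real code stores more fields per checkpoint and restores KV state on a hit):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct checkpoint {
        int32_t pos_min;
        int32_t pos_max;
    };

    // Find the most recent checkpoint whose pos_min is below the threshold.
    // Returns nullptr when no checkpoint qualifies.
    static const checkpoint * find_checkpoint(const std::vector<checkpoint> & cps,
                                              int32_t pos_min_thold) {
        const auto it = std::find_if(cps.rbegin(), cps.rend(),
            [&](const checkpoint & cur) {
                return cur.pos_min < pos_min_thold;
            });
        return it == cps.rend() ? nullptr : &*it;
    }
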
@@ -2533,47 +2536,65 @@ private:
                         slot.i_batch   = batch.n_tokens - 1;
 
                         slot.init_sampler();
+                        SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
+                    } else {
+                        // only do non-end checkpoints if the "checkpoint every n tokens" option is set
+                        do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
+                        if (do_checkpoint) {
+                            llama_pos last_checkpoint = 0;
+                            if (!slot.prompt.checkpoints.empty()) {
+                                last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
+                            }
+                            do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+                            if (do_checkpoint) {
+                                SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+                            }
+                        }
+                        SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
+                    }
 
-                        const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
-                        const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+                    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+                    const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
 
-                        // no need for empty or small checkpoints
-                        do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+                    // no need for empty or small checkpoints
+                    do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
 
-                        // no need to create checkpoints that are too close together
-                        do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+                    // no need to create checkpoints that are too close together
+                    do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
 
-                        // note: we create the checkpoint before calling llama_decode(), so the current batch is not
-                        //       yet processed and therefore it is not part of the checkpoint.
-                        if (do_checkpoint) {
-                            while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
-                                // make room for the new checkpoint, if needed
-                                const auto & cur = slot.prompt.checkpoints.front();
+                    // note: we create the checkpoint before calling llama_decode(), so the current batch is not
+                    //       yet processed and therefore it is not part of the checkpoint.
+                    if (do_checkpoint) {
+                        while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.prompt.checkpoints.front();
 
-                                SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
-                                        cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+                            SLT_WRN(slot,
+                                    "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
+                                    ", size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
 
-                                slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
-                            }
-
-                            const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                            slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
+                        }
 
-                            auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
-                                /*.pos_min  = */ pos_min,
-                                /*.pos_max  = */ pos_max,
-                                /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
-                                /*.data     = */ std::vector<uint8_t>(checkpoint_size),
-                            });
+                        const size_t checkpoint_size =
+                            llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                            llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                        auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
+                            /*.pos_min  = */ pos_min,
+                            /*.pos_max  = */ pos_max,
+                            /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
+                            /*.data     = */ std::vector<uint8_t>(checkpoint_size),
+                        });
 
-                            SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
-                                    (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
-                        }
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
+                                                     LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
 
-                        SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
-                    } else {
-                        SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
+                        SLT_WRN(slot,
+                                "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
+                                ", size = %.3f MiB)\n",
+                                (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
+                                cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
                     }
                 }
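
Condensed, the cadence test added in the else-branch above comes down to: skip unless checkpoint_every_nt is positive and at least that many tokens have been processed since the last checkpoint. A self-contained sketch under those assumptions (checkpoint_info and should_checkpoint are illustrative names; the real code additionally applies the pos_min/pos_max >= 64 and 64-token spacing guards shown above before saving the state):

    #include <cstdint>
    #include <vector>

    struct checkpoint_info {
        int64_t n_tokens; // prompt tokens covered by this checkpoint
    };

    // Decide whether to checkpoint mid-prefill, mirroring the diff's logic:
    // disabled when every_nt <= 0; otherwise require at least every_nt tokens
    // processed since the last checkpoint. n_prompt_tokens counts the tokens
    // scheduled so far, minus the current, not-yet-decoded batch.
    static bool should_checkpoint(const std::vector<checkpoint_info> & cps,
                                  int32_t n_prompt_tokens, int32_t n_batch_tokens,
                                  int32_t every_nt) {
        if (every_nt <= 0) {
            return false;
        }
        const int64_t last = cps.empty() ? 0 : cps.back().n_tokens;
        return n_prompt_tokens - n_batch_tokens - last >= every_nt;
    }
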