]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : make 2 checkpoints near the end of the prompt (#20288)
authorGeorgi Gerganov <redacted>
Tue, 10 Mar 2026 12:28:23 +0000 (14:28 +0200)
committerGitHub <redacted>
Tue, 10 Mar 2026 12:28:23 +0000 (14:28 +0200)
* server : make 2 checkpoints near the end of the prompt

* cont : adjust checkpoints

tools/server/server-context.cpp

index 3541d910d856af567f2c570135c4f02f39fde303..b86e7e608e0147c77a0c7b11b4c250a94371ded1 100644 (file)
@@ -2530,9 +2530,24 @@ private:
                         slot.n_prompt_tokens_processed++;
 
                         // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
-                        const int n_last = std::min(n_batch, 512);
-                        if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
-                            break;
+                        // create checkpoints that many tokens before the end of the prompt:
+                        //  - 4 + n_ubatch
+                        //  - 4
+                        // ref: https://github.com/ggml-org/llama.cpp/pull/20288
+                        {
+                            static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
+
+                            bool should_break = false;
+                            for (int offset : checkpoint_offsets) {
+                                const int n_last = std::min(n_batch, offset);
+                                if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
+                                    should_break = true;
+                                    break;
+                                }
+                            }
+                            if (should_break) {
+                                break;
+                            }
                         }
                     }
 
@@ -2554,18 +2569,27 @@ private:
                         slot.init_sampler();
                         SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
                     } else {
-                        // only do non-end checkpoints if the "checkpoint every n tokens" option is set
-                        do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
-                        if (do_checkpoint) {
-                            llama_pos last_checkpoint = 0;
-                            if (!slot.prompt.checkpoints.empty()) {
-                                last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
-                            }
-                            do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+                        if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
+                            // near the end of the prompt
+                            do_checkpoint = do_checkpoint && true;
+                        } else {
+                            // only do non-end checkpoints if the "checkpoint every n tokens" option is set
+                            do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
+
                             if (do_checkpoint) {
-                                SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+                                llama_pos last_checkpoint = 0;
+                                if (!slot.prompt.checkpoints.empty()) {
+                                    last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
+                                }
+
+                                do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+
+                                if (do_checkpoint) {
+                                    SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+                                }
                             }
                         }
+
                         SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
                     }