slot.n_prompt_tokens_processed++;
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
- const int n_last = std::min(n_batch, 512);
- if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
- break;
+ // create checkpoints that many tokens before the end of the prompt:
+ // - 4 + n_ubatch
+ // - 4
+ // ref: https://github.com/ggml-org/llama.cpp/pull/20288
+ {
+ // NOTE: must not be `static` — n_ubatch is a runtime value, and a static
+ // local would freeze the offsets computed on the first call for the whole
+ // process lifetime (plus pay a thread-safe init guard on every pass).
+ const int checkpoint_offsets[] = {4 + n_ubatch, 4};
+
+ bool should_break = false;
+ for (int offset : checkpoint_offsets) {
+ const int n_last = std::min(n_batch, offset);
+ if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
+ should_break = true;
+ break;
+ }
+ }
+ if (should_break) {
+ break;
+ }
}
}
slot.init_sampler();
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
} else {
- // only do non-end checkpoints if the "checkpoint every n tokens" option is set
- do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
- if (do_checkpoint) {
- llama_pos last_checkpoint = 0;
- if (!slot.prompt.checkpoints.empty()) {
- last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
- }
- do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+ if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
+ // near the end of the prompt: leave do_checkpoint unchanged (always allow an end checkpoint)
+ } else {
+ // only do non-end checkpoints if the "checkpoint every n tokens" option is set
+ do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
+
if (do_checkpoint) {
- SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+ llama_pos last_checkpoint = 0;
+ if (!slot.prompt.checkpoints.empty()) {
+ last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
+ }
+
+ do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+
+ if (do_checkpoint) {
+ SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+ }
}
}
+
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}