}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
- {"--ctx-checkpoints", "--swa-checkpoints"}, "N",
+ {"-ctxcp", "--ctx-checkpoints", "--swa-checkpoints"}, "N",
string_format("max number of context checkpoints to create per slot (default: %d)"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
[](common_params & params, int value) {
params.n_ctx_checkpoints = value;
}
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
+ add_opt(common_arg(
+ {"-cpent", "--checkpoint-every-n-tokens"}, "N",
+ string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt),
+ [](common_params & params, int value) {
+ params.checkpoint_every_nt = value;
+ }
+ ).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
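+ // example usage: --checkpoint-every-n-tokens 4096 (or LLAMA_ARG_CHECKPOINT_EVERY_NT=4096) creates a
+ // checkpoint roughly every 4096 prefilled tokens, subject to the spacing guards in server.cpp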
add_opt(common_arg(
{"-cram", "--cache-ram"}, "N",
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
std::string cls_sep = "\t"; // separator of classification sequences
// server params
- int32_t port = 8080; // server listens on this network port
- int32_t timeout_read = 600; // http read timeout in seconds
- int32_t timeout_write = timeout_read; // http write timeout in seconds
- int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
- int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
- bool cache_prompt = true; // whether to enable prompt caching
- int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
- int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
+ int32_t port = 8080; // server listens on this network port
+ int32_t timeout_read = 600; // http read timeout in seconds
+ int32_t timeout_write = timeout_read; // http write timeout in seconds
+ int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
+ int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
+ bool cache_prompt = true; // whether to enable prompt caching
+ int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
+ int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
+ int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
#include "mtmd.h"
#include "mtmd-helper.h"
+#include <algorithm>
#include <cstddef>
#include <cinttypes>
#include <memory>
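+ // iterate from the newest checkpoint backwards and pick the first one whose pos_min is below the threshold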
const auto it = std::find_if(
slot.prompt.checkpoints.rbegin(),
slot.prompt.checkpoints.rend(),
- [&](const auto & cur) {
+ [&, func_name = __func__](const auto & cur) {
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+ LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
+ func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
return cur.pos_min < pos_min_thold;
}
);
slot.i_batch = batch.n_tokens - 1;
slot.init_sampler();
+ SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
+ } else {
+ // periodic (mid-prefill) checkpoints are created only when --checkpoint-every-n-tokens is enabled
+ do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
+ if (do_checkpoint) {
+ llama_pos last_checkpoint = 0;
+ if (!slot.prompt.checkpoints.empty()) {
+ last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
+ }
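+ // n_tokens() already includes the current batch, which has not been decoded yet, so exclude it from the interval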
+ do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
+ if (do_checkpoint) {
+ SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
+ }
+ }
+ SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
+ }
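+ // the checkpoint creation logic below is shared by both paths: end-of-prompt and periodic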
- const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
- const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+ const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+ const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
- // no need for empty or small checkpoints
- do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+ // no need for empty or small checkpoints
+ do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
- // no need to create checkpoints that are too close together
- do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+ // no need to create checkpoints that are too close together
+ do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
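+ // e.g. with the last checkpoint at pos_max = 100, no new checkpoint is created until pos_max > 164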
- // note: we create the checkpoint before calling llama_decode(), so the current batch is not
- // yet processed and therefore it is not part of the checkpoint.
- if (do_checkpoint) {
- while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
- // make room for the new checkpoint, if needed
- const auto & cur = slot.prompt.checkpoints.front();
+ // note: we create the checkpoint before calling llama_decode(), so the current batch is not
+ // yet processed and therefore it is not part of the checkpoint.
+ if (do_checkpoint) {
+ while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+ // make room for the new checkpoint, if needed
+ const auto & cur = slot.prompt.checkpoints.front();
- SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
- cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
+ SLT_WRN(slot,
+ "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
+ ", size = %.3f MiB)\n",
+ cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
- slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
- }
-
- const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+ slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
+ }
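+ // two-step save: query the required size first, then copy the sequence state into the new checkpoint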
- auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
- /*.pos_min = */ pos_min,
- /*.pos_max = */ pos_max,
- /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
- /*.data = */ std::vector<uint8_t>(checkpoint_size),
- });
+ const size_t checkpoint_size =
+ llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+ auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
+ /*.pos_min = */ pos_min,
+ /*.pos_max = */ pos_max,
+ /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
+ /*.data = */ std::vector<uint8_t>(checkpoint_size),
+ });
- SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
- (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
- }
+ llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
+ LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
- } else {
- SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
+ SLT_WRN(slot,
+ "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
+ ", size = %.3f MiB)\n",
+ (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min,
+ cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
}