res->id = slot.task->id;
res->id_slot = slot.id;
- res->index = slot.task->index;
+ res->index = slot.task->index;
// keep copy of last generated text for debugging purposes
if (slots_debug) {
n_past = 0;
}
+ llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
+
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));
// the largest pos_min required for a checkpoint to be useful
- const auto pos_min_thold = std::max(0, n_past - n_swa);
+ const auto pos_min_thold = std::max(0, pos_next - n_swa);
- // note: disallow with mtmd contexts for now
- // https://github.com/ggml-org/llama.cpp/issues/17043
- if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
+ if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
}
if (pos_min > pos_min_thold) {
- // TODO: support can be added in the future when corresponding vision models get released
- GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
-
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
- SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
+ SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
- n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
- SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
+ pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
+ n_past = slot.prompt.tokens.size_up_to_pos(pos_next);
+ SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+ pos_next = 0;
n_past = 0;
}
}
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
- SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+ SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
SLT_WRN(slot, "n_past was set to %d\n", n_past);
}
- slot.n_prompt_tokens_cache = n_past;
+ slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0;
slot.prompt.tokens.keep_first(n_past);
}
}
- // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
-
- SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
-
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
- SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
-
slot.init_sampler();
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
+ // note: we create the checkpoint before calling llama_decode(), so the current batch is not
+ // yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
- SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
- cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
+ cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
- /*.pos_min = */ pos_min,
- /*.pos_max = */ pos_max,
- /*.data = */ std::vector<uint8_t>(checkpoint_size),
+ /*.pos_min = */ pos_min,
+ /*.pos_max = */ pos_max,
+ /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
+ /*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
- (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
+ (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
+
+ SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
+ } else {
+ SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}
}