alora_disabled_id = enabled_loras[0];
}
+ bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
+
+ // make a checkpoint of the parts of the memory that cannot be rolled back.
+ // checkpoints are created only if checkpointing is enabled (n_ctx_checkpoints > 0) and at least one of the following holds:
+ // - the model uses SWA and we are not using `swa_full`
+ // - the model architecture is marked as recurrent or hybrid
+ //
+ // TODO: try to make this conditional on the context or the memory module, instead of the model type
+ do_checkpoint = do_checkpoint && (
+ llama_model_is_recurrent(model) ||
+ llama_model_is_hybrid(model) ||
+ (llama_model_n_swa(model) > 0 && !params_base.swa_full)
+ );
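+ // do_checkpoint gates both the early batch split in the prompt loop below and the checkpoint creation once the prompt is done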
+
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
slot.n_prompt_tokens_processed++;
slot.n_past++;
+
+ // process the last 64 tokens of the prompt in a separate batch, so that a checkpoint can be created right before them.
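+ // note: this 64-token margin matches the minimum checkpoint size and spacing enforced below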
+ if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) {
+ break;
+ }
}
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+
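+ // query the range of positions currently stored in the memory for this sequence (the batch prepared above has not been decoded yet)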
+ const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
+ const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
+
+ // no need for empty or small checkpoints
+ do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
+
+ // no need to create checkpoints that are too close together
+ do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64);
+
+ if (do_checkpoint) {
+ // make room for the new checkpoint, if needed, by erasing the oldest one(s)
+ while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+ const auto & cur = slot.ctx_checkpoints.front();
+ SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+ cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+ slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
+ }
+
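+ // serialize only the parts of the sequence state that cannot be rolled back (LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY), rather than a full sequence snapshot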
+ const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+ auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
+ /*.pos_min = */ pos_min,
+ /*.pos_max = */ pos_max,
+ /*.data = */ std::vector<uint8_t>(checkpoint_size),
+ });
+
+ llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+ SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+ (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ }
}
}
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
-
- // make a checkpoint of the parts of the memory that cannot be rolled back.
- // checkpoints are created only if:
- // - the model uses SWA and we are not using `swa_full`
- // - the model architecture is marked as recurrent or hybrid
- //
- // TODO: try to make this conditional on the context or the memory module, instead of the model type
- const bool do_checkpoint =
- (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
- (llama_model_n_swa(model) > 0 && !params_base.swa_full);
-
- if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
- while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
- // make room for the new checkpoint, if needed
- const auto & cur = slot.ctx_checkpoints.front();
- SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
- cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-
- slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
- }
-
- const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
- auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
- /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
- /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
- /*.data = */ std::vector<uint8_t>(checkpoint_size),
- });
-
- llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
-
- SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
- (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
- }
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}