params.swa_full = true;
}
).set_env("LLAMA_ARG_SWA_FULL"));
+ add_opt(common_arg(
+ {"--swa-checkpoints"}, "N",
+ string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
+ "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+ [](common_params & params, int value) {
+ params.n_swa_checkpoints = value;
+ }
+ ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
std::string cls_sep = "\t"; // separator of classification sequences
// server params
- int32_t port = 8080; // server listens on this network port
- int32_t timeout_read = 600; // http read timeout in seconds
- int32_t timeout_write = timeout_read; // http write timeout in seconds
- int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
- int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
+ int32_t port = 8080; // server listens on this network port
+ int32_t timeout_read = 600; // http read timeout in seconds
+ int32_t timeout_write = timeout_read; // http write timeout in seconds
+ int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
+ int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
+ int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
size_t n_token_capacity,
size_t * n_token_count_out);
+#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+
+ typedef uint32_t llama_state_seq_flags;
+
+ LLAMA_API size_t llama_state_seq_get_size_ext(
+ struct llama_context * ctx,
+ llama_seq_id seq_id,
+ llama_state_seq_flags flags);
+
+ LLAMA_API size_t llama_state_seq_get_data_ext(
+ struct llama_context * ctx,
+ uint8_t * dst,
+ size_t size,
+ llama_seq_id seq_id,
+ llama_state_seq_flags flags);
+
+ LLAMA_API size_t llama_state_seq_set_data_ext(
+ struct llama_context * ctx,
+ const uint8_t * src,
+ size_t size,
+ llama_seq_id dest_seq_id,
+ llama_state_seq_flags flags);
+
//
// Decoding
//
}
}
-size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
llama_io_write_dummy io;
try {
- return state_seq_write_data(io, seq_id);
+ return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}
-size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
llama_io_write_buffer io(dst, size);
try {
- return state_seq_write_data(io, seq_id);
+ return state_seq_write_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}
-size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
llama_io_read_buffer io(src, size);
try {
- return state_seq_read_data(io, seq_id);
+ return state_seq_read_data(io, seq_id, flags);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
{
const size_t state_size = file.size() - file.tell();
llama_io_read_file io(&file);
- const size_t nread = state_seq_read_data(io, seq_id);
+ const size_t nread = state_seq_read_data(io, seq_id, 0);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
return 0;
// save the context state using stream saving
llama_io_write_file io(&file);
- state_seq_write_data(io, seq_id);
+ state_seq_write_data(io, seq_id, 0);
const size_t res = file.tell();
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
return io.n_bytes();
}
-size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);
if (memory) {
- memory->state_write(io, seq_id);
+ memory->state_write(io, seq_id, flags);
}
return io.n_bytes();
}
-size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(seq_id);
if (memory) {
- memory->state_read(io, seq_id);
+ memory->state_read(io, seq_id, flags);
}
return io.n_bytes();
}
size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
- return ctx->state_seq_get_size(seq_id);
+ return llama_state_seq_get_size_ext(ctx, seq_id, 0);
}
size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+ return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
+}
+
+size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+ return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
+}
+
+size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ return ctx->state_seq_get_size(seq_id, flags);
+}
+
+size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();
- return ctx->state_seq_get_data(seq_id, dst, size);
+ return ctx->state_seq_get_data(seq_id, dst, size, flags);
}
-size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
ctx->synchronize();
- return ctx->state_seq_set_data(seq_id, src, size);
+ return ctx->state_seq_set_data(seq_id, src, size, flags);
}
size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
size_t state_get_data( uint8_t * dst, size_t size);
size_t state_set_data(const uint8_t * src, size_t size);
- size_t state_seq_get_size(llama_seq_id seq_id);
- size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
- size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
+ size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
+ size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
+ size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
bool state_load_file(
const char * filepath,
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
- size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
- size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id);
+ size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
+ size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
//
// members
return kv_base->get_size() == kv_swa->get_size();
}
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
- kv_base->state_write(io, seq_id);
- kv_swa ->state_write(io, seq_id);
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+ if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+ kv_base->state_write(io, seq_id, flags);
+ }
+
+ kv_swa->state_write(io, seq_id, flags);
}
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
- kv_base->state_read(io, seq_id);
- kv_swa ->state_read(io, seq_id);
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+ kv_base->state_read(io, seq_id, flags);
+ }
+
+ kv_swa->state_read(io, seq_id, flags);
}
llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
// state write/load
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified_iswa specific API
return false;
}
-void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+ GGML_UNUSED(flags);
+
io.write(&n_stream, sizeof(n_stream));
for (uint32_t s = 0; s < n_stream; ++s) {
}
}
-void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ GGML_UNUSED(flags);
+
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
uint32_t n_stream_cur;
// state write/load
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_unified specific API
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}
-void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+ GGML_UNUSED(flags);
+
mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
}
-void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ GGML_UNUSED(flags);
+
mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
}
// state write/load
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_memory_hybrid specific API
return size_s_bytes;
}
-void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+ GGML_UNUSED(flags);
+
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
state_write_data(io, cell_ranges);
}
-void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+ GGML_UNUSED(flags);
+
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
// state write/load
- void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
- void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
+ void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+ void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
// state write/read
//
- virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
- virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
+ virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
+ virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
}
};
+struct swa_checkpoint {
+ llama_pos pos_min;
+ llama_pos pos_max;
+
+ std::vector<uint8_t> data;
+};
+
struct server_task_result_cmpl_final : server_task_result {
int index = 0;
std::vector<completion_token_output> generated_token_probs;
+ std::vector<swa_checkpoint> swa_checkpoints;
+
bool has_next_token = true;
bool has_new_line = false;
bool truncated = false;
slot.n_past = 0;
}
+ const auto n_swa = llama_model_n_swa(model);
+
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
- const auto n_swa = llama_model_n_swa(model);
- if (pos_min > std::max(0, slot.n_past - n_swa)) {
+ const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+ if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
- SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
- "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
- slot.n_past = 0;
+
+ // search for a SWA checkpoint
+ const auto it = std::find_if(
+ slot.swa_checkpoints.rbegin(),
+ slot.swa_checkpoints.rend(),
+ [&](const auto & cur) {
+ return cur.pos_min <= pos_min_thold;
+ }
+ );
+
+ bool do_reset = it == slot.swa_checkpoints.rend();
+
+ if (!do_reset) {
+ // restore the checkpoint
+ const size_t swa_size = it->data.size();
+ const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+ if (n != swa_size) {
+ SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+ do_reset = true;
+ } else {
+ slot.n_past = std::min(slot.n_past, it->pos_max);
+
+ SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+ }
+ }
+
+ if (do_reset) {
+ SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+ "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+ slot.n_past = 0;
+ slot.swa_checkpoints.clear();
+ }
+ }
+ }
+
+ if (n_swa > 0) {
+ const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+ // erase any checkpoints with pos_min > pos_min_thold
+ for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
+ const auto & cur = slot.swa_checkpoints[i];
+ if (cur.pos_min > pos_min_thold) {
+ slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
+
+ SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ }
}
}
}
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
+
+ // make a checkpoint with the SWA memory
+ // checkpoints are needed only if we are not using "--swa-full"
+ if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
+ if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
+ {
+ const auto & cur = slot.swa_checkpoints.back();
+
+ SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
+ cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ }
+
+ slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+ }
+
+ const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+ auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+ /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
+ /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
+ /*.data = */ std::vector<uint8_t>(swa_size),
+ });
+
+ llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+ float size_total = 0.0f;
+ for (const auto & checkpoint : slot.swa_checkpoints) {
+ size_total += (float) checkpoint.data.size() / 1024 / 1024;
+ }
+
+ SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
+ cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+ }
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}