]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add SWA checkpoints (#15293)
authorGeorgi Gerganov <redacted>
Thu, 14 Aug 2025 11:59:50 +0000 (14:59 +0300)
committerGitHub <redacted>
Thu, 14 Aug 2025 11:59:50 +0000 (14:59 +0300)
* server : add SWA checkpoints

ggml-ci

* cont : server clean-up

* server : handle state restore fails

* llama : add extended llama_state_seq_ API

* server : do not make checkpoints if --swa-full

ggml-ci

* llama : remove flags value for NONE

* server : configure number of SWA checkpoints with CLI arg

ggml-ci

* args : fix scope of new argument

15 files changed:
common/arg.cpp
common/common.h
include/llama.h
src/llama-context.cpp
src/llama-context.h
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-kv-cache-unified.h
src/llama-memory-hybrid.cpp
src/llama-memory-hybrid.h
src/llama-memory-recurrent.cpp
src/llama-memory-recurrent.h
src/llama-memory.h
tools/server/server.cpp

index 4cbc4e3bf4e2bcf7b02be90c4dc403b743a6659b..98baac4c14da284f85f35df48d3b5b1c678b64ea 100644 (file)
@@ -1507,6 +1507,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.swa_full = true;
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
+    add_opt(common_arg(
+        {"--swa-checkpoints"}, "N",
+        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        [](common_params & params, int value) {
+            params.n_swa_checkpoints = value;
+        }
+    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
index edb79b4b4351abe998d3f252b60268c55dd41ab1..75596e6b329791788e4af2d00470d94759bc72b5 100644 (file)
@@ -413,11 +413,12 @@ struct common_params {
     std::string cls_sep    = "\t";  // separator of classification sequences
 
     // server params
-    int32_t port           = 8080;         // server listens on this network port
-    int32_t timeout_read   = 600;          // http read timeout in seconds
-    int32_t timeout_write  = timeout_read; // http write timeout in seconds
-    int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
-    int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting
+    int32_t port              = 8080;         // server listens on this network port
+    int32_t timeout_read      = 600;          // http read timeout in seconds
+    int32_t timeout_write     = timeout_read; // http write timeout in seconds
+    int32_t n_threads_http    = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse     = 0;            // min chunk size to reuse from the cache via KV shifting
+    int32_t n_swa_checkpoints = 3;            // max number of SWA checkpoints per slot
 
     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";                                                                         // NOLINT
index 0c866a6edc9d2a1f39ca2d3ed3718bbc4b791dac..135eaf1b655695e8e407dedd65ff2c6fe4db53dd 100644 (file)
@@ -870,6 +870,29 @@ extern "C" {
                           size_t   n_token_capacity,
                           size_t * n_token_count_out);
 
+#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+
+    typedef uint32_t llama_state_seq_flags;
+
+    LLAMA_API size_t llama_state_seq_get_size_ext(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+           llama_state_seq_flags   flags);
+
+    LLAMA_API size_t llama_state_seq_get_data_ext(
+            struct llama_context * ctx,
+                         uint8_t * dst,
+                          size_t   size,
+                    llama_seq_id   seq_id,
+           llama_state_seq_flags   flags);
+
+    LLAMA_API size_t llama_state_seq_set_data_ext(
+            struct llama_context * ctx,
+                   const uint8_t * src,
+                          size_t   size,
+                    llama_seq_id   dest_seq_id,
+           llama_state_seq_flags   flags);
+
     //
     // Decoding
     //
index c240e1dfe5dbcddd27b9c8b5fb32c3830aebb8dc..7d7abad5d4a2dd94ce53038c1ecd5e6b2ed3b2a4 100644 (file)
@@ -1657,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
     }
 }
 
-size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) {
     llama_io_write_dummy io;
     try {
-        return state_seq_write_data(io, seq_id);
+        return state_seq_write_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
         return 0;
     }
 }
 
-size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) {
     llama_io_write_buffer io(dst, size);
     try {
-        return state_seq_write_data(io, seq_id);
+        return state_seq_write_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;
     }
 }
 
-size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) {
     llama_io_read_buffer io(src, size);
     try {
-        return state_seq_read_data(io, seq_id);
+        return state_seq_read_data(io, seq_id, flags);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
@@ -1778,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file
     {
         const size_t state_size = file.size() - file.tell();
         llama_io_read_file io(&file);
-        const size_t nread = state_seq_read_data(io, seq_id);
+        const size_t nread = state_seq_read_data(io, seq_id, 0);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;
@@ -1802,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file
 
     // save the context state using stream saving
     llama_io_write_file io(&file);
-    state_seq_write_data(io, seq_id);
+    state_seq_write_data(io, seq_id, 0);
 
     const size_t res = file.tell();
     GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
@@ -1971,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     return io.n_bytes();
 }
 
-size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        memory->state_write(io, seq_id);
+        memory->state_write(io, seq_id, flags);
     }
 
     return io.n_bytes();
 }
 
-size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
+size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        memory->state_read(io, seq_id);
+        memory->state_read(io, seq_id, flags);
     }
 
     return io.n_bytes();
@@ -2801,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const
 }
 
 size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
-    return ctx->state_seq_get_size(seq_id);
+    return llama_state_seq_get_size_ext(ctx, seq_id, 0);
 }
 
 size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0);
+}
+
+size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+    return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0);
+}
+
+size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    return ctx->state_seq_get_size(seq_id, flags);
+}
+
+size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     ctx->synchronize();
 
-    return ctx->state_seq_get_data(seq_id, dst, size);
+    return ctx->state_seq_get_data(seq_id, dst, size, flags);
 }
 
-size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) {
     ctx->synchronize();
 
-    return ctx->state_seq_set_data(seq_id, src, size);
+    return ctx->state_seq_set_data(seq_id, src, size, flags);
 }
 
 size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
index 3c53ab625a6feaa282a97e6bbba3f4172edc18f8..230ef8962b8fa41113f0c38d94798a1774a05f35 100644 (file)
@@ -111,9 +111,9 @@ struct llama_context {
     size_t state_get_data(      uint8_t * dst, size_t size);
     size_t state_set_data(const uint8_t * src, size_t size);
 
-    size_t state_seq_get_size(llama_seq_id seq_id);
-    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size);
-    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
+    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size, llama_state_seq_flags flags);
+    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);
 
     bool state_load_file(
             const char * filepath,
@@ -213,8 +213,8 @@ private:
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i  & io);
 
-    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
-    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id);
+    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
+    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id, llama_state_seq_flags flags);
 
     //
     // members
index 01d27fb4db9b1d8adb104432b8c5c64f3b2ece7c..1e363fff2a554f8ef4cacccb6612b01c8685c61a 100644 (file)
@@ -194,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
 
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    kv_base->state_write(io, seq_id);
-    kv_swa ->state_write(io, seq_id);
+void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_write(io, seq_id, flags);
+    }
+
+    kv_swa->state_write(io, seq_id, flags);
 }
 
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    kv_base->state_read(io, seq_id);
-    kv_swa ->state_read(io, seq_id);
+void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_read(io, seq_id, flags);
+    }
+
+    kv_swa->state_read(io, seq_id, flags);
 }
 
 llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
index d2650dadd3595b2614551e36941bffda8f6909af..7bc4df718d3427e3fad0fd48431a995c67f43189 100644 (file)
@@ -56,8 +56,8 @@ public:
 
     // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
     // llama_kv_cache_unified_iswa specific API
index 88c88552aaad0204e8114d31ff95728e51a9b1f1..478ebffac0f63f14138cda68ca9b2c21b1ca9229 100644 (file)
@@ -1828,7 +1828,9 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
     return false;
 }
 
-void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
     io.write(&n_stream, sizeof(n_stream));
 
     for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1879,7 +1881,9 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
     }
 }
 
-void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
     GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
 
     uint32_t n_stream_cur;
index 342a675962e2a85668d882888d98d4e71cf115c4..07a7c9e4e46a1afbd5835a92c2d0b18dfdc31483 100644 (file)
@@ -136,8 +136,8 @@ public:
 
     // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
     // llama_kv_cache_unified specific API
index e98b4e354695987d712ded4ce6974186800fad90..cbeeb21344ecede21046e05666b2b4863f1a4187 100644 (file)
@@ -165,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
     return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
 }
 
-void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
     mem_attn->state_write(io, seq_id);
     mem_recr->state_write(io, seq_id);
 }
 
-void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
     mem_attn->state_read(io, seq_id);
     mem_recr->state_read(io, seq_id);
 }
index c2d56cd541594381f40524f18ab72263481aef05..acdbc26bfb624c1a7ac22ce55d4808c6e3398148 100644 (file)
@@ -74,8 +74,8 @@ public:
 
     // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;
 
     //
     // llama_memory_hybrid specific API
index c0c2ec084dc1447787e9c3e95aa64f572244eb0c..849675c418891d9da3436d5c2a6a035cb7b95ce9 100644 (file)
@@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const {
     return size_s_bytes;
 }
 
-void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
 
@@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq
     state_write_data(io, cell_ranges);
 }
 
-void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
     uint32_t cell_count;
     io.read_to(&cell_count, sizeof(cell_count));
 
index 4d094f9a05788cb3fa18b29bc188de1926578225..95c617b2c94bdb53480d7129a0f1474e7cfbbfea 100644 (file)
@@ -63,8 +63,8 @@ public:
 
     // state write/load
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
     uint32_t size = 0; // total number of cells, shared across all sequences
index e8ba336e8525d16b2cd277eb53a60c4c36ecbc39..171d312cc99d91c3b0665322de6f8e59f2c29de0 100644 (file)
@@ -104,8 +104,8 @@ struct llama_memory_i {
     // state write/read
     //
 
-    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
-    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
+    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
 };
 
 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
index 20fc784a8217b50a1fa47b5c0532629d458bfa09..8578d49e0394b9fd98cc0fb92f56eff6739103ca 100644 (file)
@@ -692,6 +692,13 @@ struct completion_token_output {
     }
 };
 
+struct swa_checkpoint {
+    llama_pos pos_min;
+    llama_pos pos_max;
+
+    std::vector<uint8_t> data;
+};
+
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
 
@@ -1336,6 +1343,8 @@ struct server_slot {
 
     std::vector<completion_token_output> generated_token_probs;
 
+    std::vector<swa_checkpoint> swa_checkpoints;
+
     bool has_next_token = true;
     bool has_new_line   = false;
     bool truncated      = false;
@@ -3293,6 +3302,8 @@ struct server_context {
                                 slot.n_past = 0;
                             }
 
+                            const auto n_swa = llama_model_n_swa(model);
+
                             if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                                 const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                                 if (pos_min == -1) {
@@ -3300,12 +3311,58 @@ struct server_context {
                                     GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                                 }
 
-                                const auto n_swa = llama_model_n_swa(model);
-                                if (pos_min > std::max(0, slot.n_past - n_swa)) {
+                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                                if (pos_min > pos_min_thold) {
                                     SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
-                                    SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
-                                            "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-                                    slot.n_past = 0;
+
+                                    // search for a SWA checkpoint
+                                    const auto it = std::find_if(
+                                        slot.swa_checkpoints.rbegin(),
+                                        slot.swa_checkpoints.rend(),
+                                        [&](const auto & cur) {
+                                            return cur.pos_min <= pos_min_thold;
+                                        }
+                                    );
+
+                                    bool do_reset = it == slot.swa_checkpoints.rend();
+
+                                    if (!do_reset) {
+                                        // restore the checkpoint
+                                        const size_t swa_size = it->data.size();
+                                        const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                                        if (n != swa_size) {
+                                            SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                            do_reset = true;
+                                        } else {
+                                            slot.n_past = std::min(slot.n_past, it->pos_max);
+
+                                            SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                        }
+                                    }
+
+                                    if (do_reset) {
+                                        SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                                "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+                                        slot.n_past = 0;
+                                        slot.swa_checkpoints.clear();
+                                    }
+                                }
+                            }
+
+                            if (n_swa > 0) {
+                                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
+
+                                // erase any checkpoints with pos_min > pos_min_thold
+                                for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
+                                    const auto & cur = slot.swa_checkpoints[i];
+                                    if (cur.pos_min > pos_min_thold) {
+                                        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
+
+                                        SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                                    }
                                 }
                             }
                         }
@@ -3519,6 +3576,39 @@ struct server_context {
 
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
+
+                    // make a checkpoint with the SWA memory
+                    // checkpoints are needed only if we are not using "--swa-full"
+                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
+                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
+                            {
+                                const auto & cur = slot.swa_checkpoints.back();
+
+                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
+                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            }
+
+                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                        }
+
+                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                            /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
+                            /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
+                            /*.data    = */ std::vector<uint8_t>(swa_size),
+                        });
+
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+
+                        float size_total = 0.0f;
+                        for (const auto & checkpoint : slot.swa_checkpoints) {
+                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
+                        }
+
+                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
+                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                    }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }