]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor session file management (#8699)
authorcompilade <redacted>
Sun, 28 Jul 2024 04:42:05 +0000 (00:42 -0400)
committerGitHub <redacted>
Sun, 28 Jul 2024 04:42:05 +0000 (00:42 -0400)
* llama : refactor session file management

* llama : saving and restoring state checks for overflow

The size of the buffers should now be given to the functions working
with them, otherwise a truncated file could cause out of bound reads.

* llama : stream from session file instead of copying into a big buffer

Loading session files should no longer cause a memory usage spike.

* llama : llama_state_get_size returns the actual size instead of max

This is a breaking change, but makes that function *much* easier
to keep up to date, and it also makes it reflect the behavior
of llama_state_seq_get_size.

* llama : share code between whole and seq_id-specific state saving

Both session file types now use a more similar format.

* llama : no longer store all hparams in session files

Instead, the model arch name is stored.
The layer count and the embedding dimensions of the KV cache
are still verified when loading.
Storing all the hparams is not necessary.

* llama : fix uint64_t format type

* llama : various integer type cast and format string fixes

Some platforms use "%lu" and others "%llu" for uint64_t.
Not sure how to handle that, so casting to size_t when displaying errors.

* llama : remove _context suffix for llama_data_context

* llama : fix session file loading

llama_state_get_size cannot be used to get the max size anymore.

* llama : more graceful error handling of invalid session files

* llama : remove LLAMA_MAX_RNG_STATE

It's no longer necessary to limit the size of the RNG state,
because the max size of session files is not estimated anymore.

* llama : cast seq_id in comparison with unsigned n_seq_max

examples/save-load-state/save-load-state.cpp
include/llama.h
src/llama.cpp

index 00c2277ac28279525a3d1f8b928b8cc49a6fba26..d8afdc141a4a4041bbedf295a96ead51b23c0a7b 100644 (file)
@@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
         std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
-        const size_t written = llama_state_get_data(ctx, state_mem.data());
+        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
 
         FILE *fp_write = fopen("dump_state.bin", "wb");
         fwrite(state_mem.data(), 1, written, fp_write);
@@ -99,13 +99,16 @@ int main(int argc, char ** argv) {
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx2, state_mem.data())) {
+        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -159,13 +162,16 @@ int main(int argc, char ** argv) {
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx3, state_mem.data())) {
+        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
@@ -182,7 +188,7 @@ int main(int argc, char ** argv) {
     {
         // save kv of seq 0
         std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
-        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
         if (ncopy != seq_store.size()) {
             fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
             llama_free(ctx3);
@@ -196,7 +202,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s : kv cache cleared\n", __func__);
 
         // restore kv into seq 1
-        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
         if (nset != seq_store.size()) {
             fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
             llama_free(ctx3);
index 413070d95a5c4ef240b73131dafbc20d85d30341..f23355a6bc9593e53d5978ff81b381182dcccd61 100644 (file)
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
-#define LLAMA_MAX_RNG_STATE (64*1024)
-
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 7
+#define LLAMA_SESSION_VERSION 8
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
-#define LLAMA_STATE_SEQ_VERSION 1
+#define LLAMA_STATE_SEQ_VERSION 2
 
 #ifdef __cplusplus
 extern "C" {
@@ -691,10 +689,11 @@ extern "C" {
     // State / sessions
     //
 
-    // Returns the maximum size in bytes of the state (rng, logits, embedding
-    // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
-    LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+    // Returns the *actual* size in bytes of the state
+    // (rng, logits, embedding and kv_cache)
+    // Only use when saving the state, not when restoring it, otherwise the size may be too small.
+    LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
         "use llama_state_get_size instead");
 
     // Copies the state to the specified destination address.
@@ -702,7 +701,8 @@ extern "C" {
     // Returns the number of bytes copied
     LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
-                         uint8_t * dst);
+                         uint8_t * dst,
+                          size_t   size);
     LLAMA_API DEPRECATED(size_t llama_copy_state_data(
             struct llama_context * ctx,
                          uint8_t * dst),
@@ -712,7 +712,8 @@ extern "C" {
     // Returns the number of bytes read
     LLAMA_API size_t llama_state_set_data(
             struct llama_context * ctx,
-                   const uint8_t * src);
+                   const uint8_t * src,
+                          size_t   size);
     LLAMA_API DEPRECATED(size_t llama_set_state_data(
             struct llama_context * ctx,
                    const uint8_t * src),
@@ -754,6 +755,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
                          uint8_t * dst,
+                          size_t   size,
                     llama_seq_id   seq_id);
 
     // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -763,6 +765,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_set_data(
             struct llama_context * ctx,
                    const uint8_t * src,
+                          size_t   size,
                     llama_seq_id   dest_seq_id);
 
     LLAMA_API size_t llama_state_seq_save_file(
index 0345d0062233e4c41e0c0e3353aa032451cd9e1d..a207451f585071016d2571b429989984eec0ea4c 100644 (file)
@@ -2933,7 +2933,7 @@ static bool llama_kv_cache_init(
 
     // TODO: find a nicer way to add other recurrent model architectures
     cache.recurrent = model.arch == LLM_ARCH_MAMBA;
-    cache.v_trans   = !cparams.flash_attn;
+    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
     cache.size = kv_size;
@@ -17303,18 +17303,18 @@ void llama_kv_cache_update(struct llama_context * ctx) {
 }
 
 // deprecated
-size_t llama_get_state_size(const struct llama_context * ctx) {
+size_t llama_get_state_size(struct llama_context * ctx) {
     return llama_state_get_size(ctx);
 }
 
 // deprecated
 size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    return llama_state_get_data(ctx, dst);
+    return llama_state_get_data(ctx, dst, -1);
 }
 
 // deprecated
 size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    return llama_state_set_data(ctx, src);
+    return llama_state_set_data(ctx, src, -1);
 }
 
 // deprecated
@@ -17327,302 +17327,284 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     return llama_state_save_file(ctx, path_session, tokens, n_token_count);
 }
 
-// Returns the *maximum* size of the state
-size_t llama_state_get_size(const struct llama_context * ctx) {
-    const auto & cparams = ctx->cparams;
-    const auto & hparams = ctx->model.hparams;
-
-    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
-    // for reference, std::mt19937(1337) serializes to 6701 bytes.
-    const size_t s_rng_size        = sizeof(size_t);
-    const size_t s_rng             = LLAMA_MAX_RNG_STATE;
-    const size_t s_n_outputs       = sizeof(size_t);
-    // assume worst case for outputs although only currently set ones are serialized
-    const size_t s_output_pos      = ctx->cparams.n_batch * sizeof(int32_t);
-    const size_t s_logits_size     = sizeof(size_t);
-    const size_t s_logits          = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
-    const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embd_size   ? cparams.n_batch * hparams.n_embd  * sizeof(float) : 0;
-    const size_t s_kv_buf_size     = sizeof(size_t);
-    const size_t s_kv_head         = sizeof(uint32_t);
-    const size_t s_kv_size         = sizeof(uint32_t);
-    const size_t s_kv_used         = sizeof(uint32_t);
-    const size_t s_v_trans         = sizeof(uint32_t);
-    const size_t s_kv              = ctx->kv_self.total_size();
-    const size_t s_kv_cell         = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
-    const size_t s_kv_cells        = ctx->kv_self.size * s_kv_cell;
-
-    const size_t s_total = (
-        + s_rng_size
-        + s_rng
-        + s_n_outputs
-        + s_output_pos
-        + s_logits_size
-        + s_logits
-        + s_embedding_size
-        + s_embedding
-        + s_kv_buf_size
-        + s_kv_head
-        + s_kv_size
-        + s_kv_used
-        + s_v_trans
-        + s_kv
-        + s_kv_cells
-    );
-
-    // on session change it is very likely that the state size has changed - so we need to update this function
-    static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
-
-    return s_total;
-}
-
-// llama_context_data
-struct llama_data_context {
+// TODO: replace all non-fatal assertions with returned errors or exceptions
+struct llama_data_write {
     virtual void write(const void * src, size_t size) = 0;
     virtual size_t get_size_written() = 0;
-    virtual ~llama_data_context() = default;
-};
+    virtual ~llama_data_write() = default;
 
-struct llama_data_buffer_context : llama_data_context {
-    uint8_t * ptr;
-    size_t size_written = 0;
+    void write_string(const std::string & str) {
+        uint32_t str_size = str.size();
 
-    llama_data_buffer_context(uint8_t * p) : ptr(p) {}
-
-    void write(const void * src, size_t size) override {
-        memcpy(ptr, src, size);
-        ptr += size;
-        size_written += size;
+        write(&str_size,  sizeof(str_size));
+        write(str.data(), str_size);
     }
 
-    size_t get_size_written() override {
-        return size_written;
+    void write_model_info(const struct llama_context * ctx) {
+        std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+        write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
-};
 
-struct llama_data_file_context : llama_data_context {
-    llama_file * file;
-    size_t size_written = 0;
+    void write_rng(const std::mt19937 & rng) {
+        std::ostringstream rng_ss;
+        rng_ss << rng;
 
-    llama_data_file_context(llama_file * f) : file(f) {}
+        const std::string & rng_str = rng_ss.str();
 
-    void write(const void * src, size_t size) override {
-        file->write_raw(src, size);
-        size_written += size;
+        write_string(rng_str);
     }
 
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
+    void write_output_ids(const struct llama_context * ctx) {
+        const uint32_t n_outputs = ctx->n_outputs;
 
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_file_context data_ctx(&file);
- * llama_state_get_data(ctx, &data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_buffer_context data_ctx(&buf.data());
- * llama_state_get_data(ctx, &data_ctx);
- *
-*/
-static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    llama_synchronize(ctx);
+        std::vector<int32_t> output_pos;
 
-    // copy rng
-    {
-        std::ostringstream rng_ss;
-        rng_ss << ctx->sampling.rng;
+        const size_t    n_batch = ctx->cparams.n_batch;
+        const auto & output_ids = ctx->output_ids;
+
+        GGML_ASSERT(n_outputs <= ctx->output_size);
+
+        output_pos.resize(n_outputs);
+
+        // build a more compact representation of the output ids
+        for (size_t i = 0; i < n_batch; ++i) {
+            // map an output id to a position in the batch
+            int32_t pos = output_ids[i];
+            if (pos >= 0) {
+                GGML_ASSERT((uint32_t) pos < n_outputs);
+                output_pos[pos] = i;
+            }
+        }
 
-        const std::string & rng_str  = rng_ss.str();
-        const size_t        rng_size = rng_str.size();
+        write(&n_outputs, sizeof(n_outputs));
 
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+        if (n_outputs) {
+            write(output_pos.data(), n_outputs * sizeof(int32_t));
+        }
+    }
 
-        data_ctx->write(&rng_size,      sizeof(rng_size));
-        data_ctx->write(rng_str.data(), rng_size);
+    void write_logits(const struct llama_context * ctx) {
+        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+
+        write(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            write(ctx->logits, logits_size * sizeof(float));
+        }
     }
 
-    // copy outputs
-    {
-        // Can't use ctx->n_outputs because it's not for the
-        // entire last batch when n_ubatch is smaller than n_batch
-        size_t n_outputs = 0;
+    void write_embeddings(const struct llama_context * ctx) {
+        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
 
-        // copy output ids
-        {
-            std::vector<int32_t> output_pos;
+        write(&embeddings_size, sizeof(embeddings_size));
+
+        if (embeddings_size) {
+            write(ctx->embd, embeddings_size * sizeof(float));
+        }
+    }
+
+    void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
 
-            const size_t    n_batch = ctx->cparams.n_batch;
-            const auto & output_ids = ctx->output_ids;
+        for (const auto & range : cell_ranges) {
+            for (uint32_t i = range.first; i < range.second; ++i) {
+                const auto & cell = kv_self.cells[i];
+                const llama_pos pos      = cell.pos;
+                const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
 
-            output_pos.resize(ctx->output_size);
+                write(&pos,      sizeof(pos));
+                write(&n_seq_id, sizeof(n_seq_id));
 
-            // build a more compact representation of the output ids
-            for (size_t i = 0; i < n_batch; ++i) {
-                // map an output id to a position in the batch
-                int32_t pos = output_ids[i];
-                if (pos >= 0) {
-                    if ((size_t) pos >= n_outputs) {
-                        n_outputs = pos + 1;
+                if (n_seq_id) {
+                    for (auto seq_id : cell.seq_id) {
+                        write(&seq_id, sizeof(seq_id));
                     }
-                    GGML_ASSERT((size_t) pos < ctx->output_size);
-                    output_pos[pos] = i;
                 }
             }
+        }
+    }
 
-            data_ctx->write(&n_outputs, sizeof(n_outputs));
+    void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
+        const struct llama_kv_cache & kv_self = ctx->kv_self;
+        const struct llama_hparams & hparams = ctx->model.hparams;
 
-            if (n_outputs) {
-                data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
-            }
-        }
+        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
+        const uint32_t n_layer = hparams.n_layer;
 
-        // copy logits
-        {
-            const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
+        write(&v_trans, sizeof(v_trans));
+        write(&n_layer, sizeof(n_layer));
 
-            data_ctx->write(&logits_size, sizeof(logits_size));
+        std::vector<uint8_t> tmp_buf;
 
-            if (logits_size) {
-                data_ctx->write(ctx->logits, logits_size * sizeof(float));
-            }
-        }
+        // Iterate and write all the keys first, each row is a cell
+        // Get whole range at a time
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
-        // copy embeddings
-        {
-            const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+            // Write key type
+            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+            write(&k_type_i, sizeof(k_type_i));
 
-            data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+            // Write row size of key
+            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            write(&k_size_row, sizeof(k_size_row));
 
-            if (embeddings_size) {
-                data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+            // Read each range of cells of k_size length each into tmp_buf and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                tmp_buf.resize(range_size * k_size_row);
+                ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
+                write(tmp_buf.data(), tmp_buf.size());
             }
         }
-    }
 
-    // copy kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
-
-        const uint32_t n_layer      = hparams.n_layer;
-
-        // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
-        const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
-        const uint32_t kv_size     = kv_self.size;
-        const size_t   kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
-        const uint32_t kv_used     = kv_self.used;
-        const uint32_t v_trans     = kv_self.v_trans ? 1 : 0;
-
-        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
-        data_ctx->write(&kv_head,     sizeof(kv_head));
-        data_ctx->write(&kv_size,     sizeof(kv_size));
-        data_ctx->write(&kv_used,     sizeof(kv_used));
-        data_ctx->write(&v_trans,     sizeof(v_trans));
-
-        if (kv_buf_size) {
-            const size_t pre_kv_buf_size = data_ctx->get_size_written();
-
-            std::vector<uint8_t> tmp_buf;
-            for (int il = 0; il < (int) n_layer; ++il) {
-                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        if (!kv_self.v_trans) {
+            for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+                // Write value type
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                write(&v_type_i, sizeof(v_type_i));
 
-                tmp_buf.resize(k_size);
-                ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
-                data_ctx->write(tmp_buf.data(), tmp_buf.size());
+                // Write row size of value
+                const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                write(&v_size_row, sizeof(v_size_row));
 
-                if (kv_self.recurrent || !kv_self.v_trans) {
-                    // v is contiguous for recurrent models
-                    // TODO: use other tensors for state models than k and v
-                    const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
-
-                    tmp_buf.resize(v_size);
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
-                    data_ctx->write(tmp_buf.data(), tmp_buf.size());
-                    continue;
+                // Read each range of cells of v_size length each into tmp_buf and write out
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    tmp_buf.resize(range_size * v_size_row);
+                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
+                    write(tmp_buf.data(), tmp_buf.size());
                 }
+            }
+        } else {
+            // When v is transposed, we also need the element size and get the element ranges from each row
+            const uint32_t kv_size = kv_self.size;
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-                // v is not contiguous, copy row by row
-                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
+                // Write value type
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                write(&v_type_i, sizeof(v_type_i));
 
-                tmp_buf.resize(v_row_size);
-                for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*v_row_stride, tmp_buf.size());
-                    data_ctx->write(tmp_buf.data(), tmp_buf.size());
+                // Write element size
+                const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+                write(&v_size_el, sizeof(v_size_el));
+
+                // Write GQA embedding size
+                write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+                // For each row, we get the element values of each cell
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    // Read each range of cells of v_size_el length each into tmp_buf and write out
+                    for (const auto & range : cell_ranges) {
+                        const size_t range_size = range.second - range.first;
+                        const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                        tmp_buf.resize(range_size * v_size_el);
+                        ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+                        write(tmp_buf.data(), tmp_buf.size());
+                    }
                 }
             }
-            GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
         }
+    }
 
-        for (uint32_t i = 0; i < kv_head; ++i) {
-            const auto & cell = kv_self.cells[i];
-
-            const llama_pos pos         = cell.pos;
-            const size_t    seq_id_size = cell.seq_id.size();
-
-            data_ctx->write(&pos,         sizeof(pos));
-            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        const struct llama_kv_cache & kv_self = ctx->kv_self;
+        std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+        uint32_t cell_count = 0;
 
-            for (auto seq_id : cell.seq_id) {
-                data_ctx->write(&seq_id, sizeof(seq_id));
+        // Count the number of cells with the specified seq_id
+        // Find all the ranges of cells with this seq id (or all, when -1)
+        uint32_t cell_range_begin = kv_self.size;
+        for (uint32_t i = 0; i < kv_self.size; ++i) {
+            const auto & cell = kv_self.cells[i];
+            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+                ++cell_count;
+                if (cell_range_begin == kv_self.size) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != kv_self.size) {
+                    cell_ranges.emplace_back(cell_range_begin, i);
+                    cell_range_begin = kv_self.size;
+                }
             }
         }
-    }
-}
+        if (cell_range_begin != kv_self.size) {
+            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
+        }
 
-size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
-    llama_data_buffer_context data_ctx(dst);
-    llama_state_get_data_internal(ctx, &data_ctx);
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cell_ranges) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
 
-    return data_ctx.get_size_written();
-}
+        write(&cell_count, sizeof(cell_count));
 
-// Sets the state reading from the specified source address
-size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
-    llama_synchronize(ctx);
+        write_kv_cache_meta(kv_self, cell_ranges, seq_id);
+        write_kv_cache_data(ctx, cell_ranges);
+    }
+};
 
-    const uint8_t * inp = src;
+struct llama_data_read {
+    virtual const uint8_t * read(size_t size) = 0;
+    virtual void read_to(void * dst, size_t size) = 0;
+    virtual size_t get_size_read() = 0;
+    virtual ~llama_data_read() = default;
 
-    // set rng
-    {
-        size_t rng_size;
-        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
+    void read_string(std::string & str) {
+        uint32_t str_size;
+        read_to(&str_size, sizeof(str_size));
 
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+        str.assign((const char *) read(str_size), str_size);
+    }
+
+    // validate model information
+    void read_model_info(const struct llama_context * ctx) {
+        std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+        std::string arch_str;
+        read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+        }
+        // TODO: add more info which needs to be identical but which is not verified otherwise
+    }
 
-        std::string rng_str((const char *)inp, rng_size); inp += rng_size;
+    void read_rng(std::mt19937 & rng) {
+        std::string rng_str;
+        read_string(rng_str);
 
         std::istringstream rng_ss(rng_str);
-        rng_ss >> ctx->sampling.rng;
+        rng_ss >> rng;
 
-        GGML_ASSERT(!rng_ss.fail());
+        if (rng_ss.fail()) {
+            throw std::runtime_error("failed to load RNG state");
+        }
     }
 
-    // set output ids
-    {
-        size_t n_outputs;
+    void read_output_ids(struct llama_context * ctx) {
         std::vector<int32_t> output_pos;
 
-        memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+        uint32_t n_outputs;
+        read_to(&n_outputs, sizeof(n_outputs));
 
-        GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs));
+        if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
+            throw std::runtime_error("could not reserve outputs");
+        }
 
         if (n_outputs) {
             output_pos.resize(n_outputs);
-            memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
-            inp += n_outputs * sizeof(int32_t);
+            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
 
             for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
                 int32_t id = output_pos[i];
-                GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+                if ((uint32_t) id >= ctx->cparams.n_batch) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
+                }
                 ctx->output_ids[id] = i;
             }
 
@@ -17630,128 +17612,434 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
         }
     }
 
-    // set logits
-    {
-        size_t logits_size;
-
-        memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
+    void read_logits(struct llama_context * ctx) {
+        uint64_t logits_size;
+        read_to(&logits_size, sizeof(logits_size));
 
-        GGML_ASSERT(ctx->logits_size >= logits_size);
+        if (ctx->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
 
         if (logits_size) {
-            memcpy(ctx->logits, inp, logits_size * sizeof(float));
-            inp += logits_size * sizeof(float);
+            read_to(ctx->logits, logits_size * sizeof(float));
         }
     }
 
-    // set embeddings
-    {
-        size_t embeddings_size;
-
-        memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
+    void read_embeddings(struct llama_context * ctx) {
+        uint64_t embeddings_size;
+        read_to(&embeddings_size, sizeof(embeddings_size));
 
-        GGML_ASSERT(ctx->embd_size >= embeddings_size);
+        if (ctx->embd_size < embeddings_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
 
         if (embeddings_size) {
-            memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
-            inp += embeddings_size * sizeof(float);
+            read_to(ctx->embd, embeddings_size * sizeof(float));
         }
     }
 
-    // set kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
+    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
+        struct llama_kv_cache & kv_self = ctx->kv_self;
 
-        const uint32_t n_layer      = hparams.n_layer;
+        if (dest_seq_id != -1) {
+            // single sequence
 
-        size_t   kv_buf_size;
-        uint32_t kv_head;
-        uint32_t kv_size;
-        uint32_t kv_used;
-        uint32_t v_trans;
+            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+
+            llama_batch batch = llama_batch_init(cell_count, 0, 1);
+            batch.n_tokens = cell_count;
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                llama_pos pos;
+                uint32_t n_seq_id;
 
-        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
-        memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
-        memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
-        memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);
-        memcpy(&v_trans,     inp, sizeof(v_trans));     inp += sizeof(v_trans);
+                read_to(&pos, sizeof(pos));
+                read_to(&n_seq_id, sizeof(n_seq_id));
 
-        GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
+                if (n_seq_id != 0) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                    return false;
+                }
 
-        if (kv_self.size != kv_size) {
-            // the KV cache needs to be big enough to load all the KV cells from the saved state
-            GGML_ASSERT(kv_self.size >= kv_head);
+                batch.pos[i] = pos;
+                batch.n_seq_id[i] = 1;
+                batch.seq_id[i][0] = dest_seq_id;
+            }
+            if (!llama_kv_cache_find_slot(kv_self, batch)) {
+                llama_batch_free(batch);
+                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+                return false;
+            }
+
+            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+            // Assume that this is one contiguous block of cells
+            GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
+            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
 
-            LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
-                __func__, kv_head, kv_size, kv_self.size);
+            // Cleanup
+            llama_batch_free(batch);
+        } else {
+            // whole KV cache restore
+
+            if (cell_count > kv_self.size) {
+                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+                return false;
+            }
+
+            llama_kv_cache_clear(kv_self);
+
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                llama_kv_cell & cell = kv_self.cells[i];
+
+                llama_pos pos;
+                uint32_t  n_seq_id;
+
+                read_to(&pos,      sizeof(pos));
+                read_to(&n_seq_id, sizeof(n_seq_id));
+
+                cell.pos = pos;
+
+                for (uint32_t j = 0; j < n_seq_id; ++j) {
+                    llama_seq_id seq_id;
+                    read_to(&seq_id, sizeof(seq_id));
+
+                    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                        return false;
+                    }
+
+                    cell.seq_id.insert(seq_id);
+                }
+            }
+
+            kv_self.head = 0;
+            kv_self.used = cell_count;
+        }
+
+        return true;
+    }
+
+    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
+        const struct llama_hparams & hparams = ctx->model.hparams;
+        struct llama_kv_cache & kv_self = ctx->kv_self;
+        uint32_t v_trans;
+        uint32_t n_layer;
+        read_to(&v_trans, sizeof(v_trans));
+        read_to(&n_layer, sizeof(n_layer));
+
+        if (n_layer != hparams.n_layer) {
+            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+            return false;
+        }
+        if (cell_count > kv_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
+            return false;
+        }
+        if (kv_self.v_trans != (bool) v_trans) {
+            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+            return false;
         }
 
-        llama_kv_cache_clear(ctx);
+        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
-        if (kv_buf_size) {
-            const size_t pre_kv_buf_size = inp - src;
+            // Read type of key
+            int32_t k_type_i_ref;
+            read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+            if (k_type_i != k_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of key
+            uint64_t k_size_row_ref;
+            read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            if (k_size_row != k_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+                return false;
+            }
 
-            GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
+            if (cell_count) {
+                // Read and set the keys for the whole cell range
+                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
+            }
+        }
 
-            for (int il = 0; il < (int) n_layer; ++il) {
-                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        if (!kv_self.v_trans) {
+            for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+                // Read type of value
+                int32_t v_type_i_ref;
+                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                if (v_type_i != v_type_i_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                    return false;
+                }
 
-                ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
-                inp += k_size;
+                // Read row size of value
+                uint64_t v_size_row_ref;
+                read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+                const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                if (v_size_row != v_size_row_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                    return false;
+                }
 
-                if (kv_self.recurrent || !kv_self.v_trans) {
-                    // v is contiguous for recurrent models
-                    // TODO: use other tensors for state models than k and v
-                    const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+                if (cell_count) {
+                    // Read and set the values for the whole cell range
+                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
+                }
+            }
+        } else {
+            // For each layer, read the values for each cell (transposed)
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
-                    inp += v_size;
-                    continue;
+                // Read type of value
+                int32_t v_type_i_ref;
+                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                if (v_type_i != v_type_i_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                    return false;
                 }
 
-                // v is not contiguous, copy row by row
-                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
+                // Read element size of value
+                uint32_t v_size_el_ref;
+                read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+                const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+                if (v_size_el != v_size_el_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                    return false;
+                }
+
+                // Read GQA embedding size
+                uint32_t n_embd_v_gqa_ref;
+                read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+                if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                    return false;
+                }
 
-                for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
-                    inp += v_row_size;
+                if (cell_count) {
+                    // For each row in the transposed matrix, read the values for the whole cell range
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
+                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    }
                 }
             }
-            GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
         }
+        return true;
+    }
 
-        ctx->kv_self.head = kv_head;
-        ctx->kv_self.used = kv_used;
+    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        uint32_t cell_count;
+        read_to(&cell_count, sizeof(cell_count));
 
-        for (uint32_t i = 0; i < kv_head; ++i) {
-            llama_pos pos;
-            size_t    seq_id_size;
+        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
 
-            memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos);
-            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
+        if (!res) {
+            if (seq_id == -1) {
+                llama_kv_cache_clear(ctx);
+            } else {
+                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
+            }
+            throw std::runtime_error("failed to restore kv cache");
+        }
+    }
+};
 
-            ctx->kv_self.cells[i].pos = pos;
+struct llama_data_write_dummy : llama_data_write {
+    size_t size_written = 0;
 
-            llama_seq_id seq_id;
+    llama_data_write_dummy() {}
 
-            for (size_t j = 0; j < seq_id_size; ++j) {
-                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
-                ctx->kv_self.cells[i].seq_id.insert(seq_id);
-            }
+    // TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
+
+    void write(const void * /* src */, size_t size) override {
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_write_buffer : llama_data_write {
+    uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_written = 0;
+
+    llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    void write(const void * src, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
         }
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
     }
 
-    const size_t nread    = inp - src;
-    const size_t max_size = llama_state_get_size(ctx);
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
 
-    GGML_ASSERT(nread <= max_size);
+struct llama_data_read_buffer : llama_data_read {
+    const uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_read = 0;
 
-    return nread;
+    llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    const uint8_t * read(size_t size) override {
+        const uint8_t * base_ptr = ptr;
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ptr += size;
+        size_read += size;
+        buf_size -= size;
+        return base_ptr;
+    }
+
+    void read_to(void * dst, size_t size) override {
+        memcpy(dst, read(size), size);
+    }
+
+    size_t get_size_read() override {
+        return size_read;
+    }
+};
+
+struct llama_data_write_file : llama_data_write {
+    llama_file * file;
+    size_t size_written = 0;
+
+    llama_data_write_file(llama_file * f) : file(f) {}
+
+    void write(const void * src, size_t size) override {
+        file->write_raw(src, size);
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_read_file : llama_data_read {
+    llama_file * file;
+    size_t size_read = 0;
+    std::vector<uint8_t> temp_buffer;
+
+    llama_data_read_file(llama_file * f) : file(f) {}
+
+    void read_to(void * dst, size_t size) override {
+        file->read_raw(dst, size);
+        size_read += size;
+    }
+
+    const uint8_t * read(size_t size) override {
+        temp_buffer.resize(size);
+        read_to(temp_buffer.data(), size);
+        return temp_buffer.data();
+    }
+
+    size_t get_size_read() override {
+        return size_read;
+    }
+};
+
+/** copy state data into either a buffer or file depending on the passed in context
+ *
+ * file context:
+ * llama_file file("/path", "wb");
+ * llama_data_write_file data_ctx(&file);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+ * buffer context:
+ * std::vector<uint8_t> buf(max_size, 0);
+ * llama_data_write_buffer data_ctx(buf.data(), max_size);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+*/
+static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
+    llama_synchronize(ctx);
+
+    data_ctx.write_model_info(ctx);
+
+    data_ctx.write_rng(ctx->sampling.rng);
+
+    // copy outputs
+    data_ctx.write_output_ids(ctx);
+    data_ctx.write_logits(ctx);
+    data_ctx.write_embeddings(ctx);
+
+    data_ctx.write_kv_cache(ctx);
+
+    return data_ctx.get_size_written();
+}
+
+size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
+    llama_data_write_buffer data_ctx(dst, size);
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+// Returns the *actual* size of the state.
+// Intended to be used when saving to state to a buffer.
+size_t llama_state_get_size(struct llama_context * ctx) {
+    llama_data_write_dummy data_ctx;
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
+    llama_synchronize(ctx);
+
+    data_ctx.read_model_info(ctx);
+
+    // set rng
+    data_ctx.read_rng(ctx->sampling.rng);
+
+    // set outputs
+    data_ctx.read_output_ids(ctx);
+    data_ctx.read_logits(ctx);
+    data_ctx.read_embeddings(ctx);
+
+    data_ctx.read_kv_cache(ctx);
+
+    return data_ctx.get_size_read();
+}
+
+// Sets the state reading from the specified source address
+size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
+    llama_data_read_buffer data_ctx(src, size);
+    try {
+        return llama_state_set_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
 static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
@@ -17763,15 +18051,7 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
         const uint32_t version = file.read_u32();
 
         if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-            return false;
-        }
-
-        llama_hparams session_hparams;
-        file.read_raw(&session_hparams, sizeof(llama_hparams));
-
-        if (session_hparams != ctx->model.hparams) {
-            LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__);
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
             return false;
         }
     }
@@ -17781,7 +18061,7 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
         const uint32_t n_token_count = file.read_u32();
 
         if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
             return false;
         }
 
@@ -17792,19 +18072,15 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_max = llama_state_get_size(ctx);
 
-        if (n_state_size_cur > n_state_size_max) {
-            LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
+        llama_data_read_file data_ctx(&file);
+        const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
+
+        if (n_read != n_state_size_cur) {
+            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
             return false;
         }
-
-        std::vector<uint8_t> state_data(n_state_size_max);
-        file.read_raw(state_data.data(), n_state_size_cur);
-
-        llama_state_set_data(ctx, state_data.data());
     }
-
     return true;
 }
 
@@ -17812,7 +18088,7 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session
     try {
         return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
         return false;
     }
 }
@@ -17823,15 +18099,13 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
     file.write_u32(LLAMA_SESSION_MAGIC);
     file.write_u32(LLAMA_SESSION_VERSION);
 
-    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
-
     // save the prompt
     file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
     // save the context state using stream saving
-    llama_data_file_context data_ctx(&file);
-    llama_state_get_data_internal(ctx, &data_ctx);
+    llama_data_write_file data_ctx(&file);
+    llama_state_get_data_internal(ctx, data_ctx);
 
     return true;
 }
@@ -17840,401 +18114,50 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
     try {
         return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
         return false;
     }
 }
 
-size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
-    // save the size of size_t as a uint32_t for safety check
-    const size_t size_t_size_size = sizeof(uint32_t);
-
-    // other values
-    const size_t s_cell_count_size = sizeof(uint32_t);
-    const size_t s_layer_count_size = sizeof(uint32_t);
-    const size_t n_embd_v_gqa_size = sizeof(uint32_t);
-
-    size_t s_cell_count = 0;
-    size_t s_cell_data_size = 0;
-    const auto & kv_self = ctx->kv_self;
-    const auto & hparams = ctx->model.hparams;
-
-    const uint32_t n_layer = hparams.n_layer;
-
-    for (uint32_t i = 0; i < kv_self.size; ++i) {
-        const auto & cell = kv_self.cells[i];
-        if (cell.seq_id.count(seq_id) > 0) {
-            ++s_cell_count;
-            s_cell_data_size += sizeof(llama_pos);
-        }
-    }
-
-    for (int il = 0; il < (int)n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-        // types of keys and values
-        s_cell_data_size += sizeof(int32_t) * 2;
-        // k_size_row and v_size_el values of layer
-        s_cell_data_size += sizeof(size_t) * 2;
-
-        // keys
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        s_cell_data_size += k_size_row * s_cell_count;
-
-        // values (transposed)
-        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-        s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa;
-    }
-
-    const size_t s_total = (
-        size_t_size_size +
-        s_cell_count_size +
-        s_layer_count_size +
-        n_embd_v_gqa_size +
-        s_cell_data_size
-        );
-
-    return s_total;
-}
-
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
     llama_synchronize(ctx);
 
-    const auto & kv_self = ctx->kv_self;
-    GGML_ASSERT(!kv_self.recurrent); // not implemented
-
-    // Save the size of size_t as a uint32_t for safety check
-    const uint32_t size_t_size = sizeof(size_t);
-    data_ctx.write(&size_t_size, sizeof(size_t_size));
-
-    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-    uint32_t cell_count = 0;
-
-    // Count the number of cells with the specified seq_id
-    // Find all the ranges of cells with this seq id
-    {
-        uint32_t cell_range_begin = kv_self.size;
-        for (uint32_t i = 0; i < kv_self.size; ++i) {
-            const auto & cell = kv_self.cells[i];
-            if (cell.has_seq_id(seq_id)) {
-                ++cell_count;
-                if (cell_range_begin == kv_self.size) {
-                    cell_range_begin = i;
-                }
-            }
-            else {
-                if (cell_range_begin != kv_self.size) {
-                    cell_ranges.emplace_back(cell_range_begin, i);
-                    cell_range_begin = kv_self.size;
-                }
-            }
-        }
-        if (cell_range_begin != kv_self.size) {
-            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
-        }
-
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cell_ranges) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-    }
-
-    // Write the cell count
-    data_ctx.write(&cell_count, sizeof(cell_count));
-
-    const auto & hparams = ctx->model.hparams;
-    const uint32_t n_layer = hparams.n_layer;
-
-    // Write the layer count
-    data_ctx.write(&n_layer, sizeof(n_layer));
-
-    // Write n_embd_v_gqa (reference value)
-    {
-        const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
-        data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-    }
-
-    // Iterate the ranges and write all the pos (this is the token position in the prompt)
-    for (const auto & range : cell_ranges) {
-        for (uint32_t i = range.first; i < range.second; ++i) {
-            const auto & cell = kv_self.cells[i];
-            data_ctx.write(&cell.pos, sizeof(cell.pos));
-        }
-    }
-
-    // Iterate and write all the keys first, each row is a cell
-    // Get whole range at a time
-    std::vector<uint8_t> tmp_buf;
-    for (int il = 0; il < (int)n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-        // Write key type
-        const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-        data_ctx.write(&k_type_i, sizeof(k_type_i));
-
-        // Write row size of key
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        data_ctx.write(&k_size_row, sizeof(k_size_row));
-
-        // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cell_ranges) {
-            const size_t range_size = range.second - range.first;
-            tmp_buf.resize(range_size * k_size_row);
-            ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
-            data_ctx.write(tmp_buf.data(), tmp_buf.size());
-        }
-    }
-
-    // TODO: simplify, reduce copy-paste
-    if (!kv_self.v_trans) {
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Write value type
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            data_ctx.write(&v_type_i, sizeof(v_type_i));
-
-            // Write row size of value
-            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            data_ctx.write(&v_size_row, sizeof(v_size_row));
-
-            // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
-                const size_t range_size = range.second - range.first;
-                tmp_buf.resize(range_size * v_size_row);
-                ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
-                data_ctx.write(tmp_buf.data(), tmp_buf.size());
-            }
-        }
-    } else {
-        // For the values, they are transposed, so we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = kv_self.size;
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Write value type
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            data_ctx.write(&v_type_i, sizeof(v_type_i));
-
-            // Write element size
-            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-            data_ctx.write(&v_size_el, sizeof(v_size_el));
-
-            // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    tmp_buf.resize(range_size * v_size_el);
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
-                    data_ctx.write(tmp_buf.data(), tmp_buf.size());
-                }
-            }
-        }
-    }
+    data_ctx.write_kv_cache(ctx, seq_id);
 
     return data_ctx.get_size_written();
 }
 
-size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) {
-    llama_data_buffer_context data_ctx(dst);
+size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_data_write_dummy data_ctx;
     return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
 }
 
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
-    llama_synchronize(ctx);
-
-    auto & kv_self = ctx->kv_self;
-    GGML_ASSERT(!kv_self.recurrent); // not implemented
-
-    // Wipe the slot
-    llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-
-    const uint8_t * inp = src;
-
-    // Read size of size_t
-    uint32_t size_t_size;
-    memcpy(&size_t_size, inp, sizeof(size_t_size));
-    inp += sizeof(size_t_size);
-    if (size_t_size != sizeof(size_t)) {
-        LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__);
+size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    llama_data_write_buffer data_ctx(dst, size);
+    try {
+        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
         return 0;
     }
+}
 
-    // Read the cell count
-    uint32_t cell_count;
-    memcpy(&cell_count, inp, sizeof(cell_count));
-    inp += sizeof(cell_count);
-
-    // Read the layer count
-    uint32_t n_layer_ref;
-    memcpy(&n_layer_ref, inp, sizeof(n_layer_ref));
-    inp += sizeof(n_layer_ref);
-
-    // Read n_embd_v_gqa
-    uint32_t n_embd_v_gqa_ref;
-    memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
-    inp += sizeof(n_embd_v_gqa_ref);
+static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
+    llama_synchronize(ctx);
 
-    // Sanity check model compatibility
-    const auto & hparams = ctx->model.hparams;
-    const uint32_t n_layer = hparams.n_layer;
+    data_ctx.read_kv_cache(ctx, dest_seq_id);
 
-    if (n_layer != n_layer_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
-        return 0;
-    }
+    return data_ctx.get_size_read();
+}
 
-    if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
+size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
+    llama_data_read_buffer data_ctx(src, size);
+    try {
+        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
         return 0;
     }
-
-    // Allocate the new cells for the slot
-    if (cell_count) {
-        llama_batch batch = llama_batch_init(cell_count, 0, 1);
-        batch.n_tokens = cell_count;
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            memcpy(&pos, inp, sizeof(pos));
-            inp += sizeof(pos);
-
-            batch.pos[i] = pos;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id[i][0] = dest_seq_id;
-        }
-        if (!llama_kv_cache_find_slot(kv_self, batch)) {
-            llama_batch_free(batch);
-            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-            return 0;
-        }
-
-        // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-        GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-
-        // Cleanup
-        llama_batch_free(batch);
-    }
-
-    const uint32_t kv_size = kv_self.size;
-    const uint32_t kv_head = kv_self.head;
-
-    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
-    for (int il = 0; il < (int)n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-        // Read type of key
-        int32_t k_type_i_ref;
-        memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
-        inp += sizeof(k_type_i_ref);
-        const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-        if (k_type_i != k_type_i_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-            return 0;
-        }
-
-        // Read row size of key
-        size_t k_size_row_ref;
-        memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
-        inp += sizeof(k_size_row_ref);
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il);
-            return 0;
-        }
-
-        if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row);
-            inp += cell_count * k_size_row;
-        }
-    }
-
-    // TODO: simplify, reduce copy-paste
-    if (!kv_self.v_trans) {
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
-            inp += sizeof(v_type_i_ref);
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return 0;
-            }
-
-            // Read row size of value
-            size_t v_size_row_ref;
-            memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
-            inp += sizeof(v_size_row_ref);
-            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
-                return 0;
-            }
-
-            if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
-                inp += cell_count * v_size_row;
-            }
-        }
-    } else {
-        // For each layer, read the values for each cell (transposed)
-        for (int il = 0; il < (int)n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
-            inp += sizeof(v_type_i_ref);
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return 0;
-            }
-
-            // Read element size of value
-            size_t v_size_el_ref;
-            memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
-            inp += sizeof(v_size_el_ref);
-            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-            if (v_size_el != v_size_el_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
-                return 0;
-            }
-
-            if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
-                    inp += cell_count * v_size_el;
-                }
-            }
-        }
-    }
-
-    const size_t nread = inp - src;
-
-    return nread;
 }
 
 static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
@@ -18244,11 +18167,11 @@ static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, con
     file.write_u32(LLAMA_STATE_SEQ_VERSION);
 
     // save the prompt
-    file.write_u32((uint32_t)n_token_count);
+    file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
     // save the context state using stream saving
-    llama_data_file_context data_ctx(&file);
+    llama_data_write_file data_ctx(&file);
     llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
 
     const size_t res = file.tell();
@@ -18286,9 +18209,8 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
     // restore the context state
     {
         const size_t state_size = file.size - file.tell();
-        std::vector<uint8_t> state_data(state_size);
-        file.read_raw(state_data.data(), state_size);
-        const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
+        llama_data_read_file data_ctx(&file);
+        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;
@@ -18304,7 +18226,7 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -18313,7 +18235,7 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }