]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : minimize size used for state save/load (#4820)
authorDavid Friehs <redacted>
Sat, 13 Jan 2024 16:29:43 +0000 (17:29 +0100)
committerGitHub <redacted>
Sat, 13 Jan 2024 16:29:43 +0000 (18:29 +0200)
* examples : save-load-state: save only required state

* llama : only reserve n_vocab * n_batch at most for logits

llama_decode asserts that only n_batch tokens are passed each call, and
n_ctx is expected to be bigger than n_batch.

* llama : always reserve n_vocab * n_batch for logits

llama_context de-serialization breaks if the contexts have differing
capacity for logits and llama_decode will at maximum resize to
n_vocab * n_batch.

* llama : only save and restore used logits

for batch sizes of 512 this reduces save state in the best case by
around 62 MB, which can be a lot if planning to save on each message
to allow regenerating messages.

* llama : use ostringstream and istringstream for save and load

* llama : serialize rng into minimum amount of space required

* llama : break session version due to serialization changes

examples/save-load-state/save-load-state.cpp
llama.cpp
llama.h

index 48d80111010dfa268ddab78fec55651b3e612981..ef952e2bd987caf203649049f4fc81f928784cb4 100644 (file)
@@ -45,13 +45,13 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
         std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
+        const size_t written = llama_copy_state_data(ctx, state_mem.data());
 
-        {
-            FILE *fp_write = fopen("dump_state.bin", "wb");
-            llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file
-            fwrite(state_mem.data(), 1, state_mem.size(), fp_write);
-            fclose(fp_write);
-        }
+        FILE *fp_write = fopen("dump_state.bin", "wb");
+        fwrite(state_mem.data(), 1, written, fp_write);
+        fclose(fp_write);
+
+        fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
     }
 
     // save state (last tokens)
@@ -100,18 +100,17 @@ int main(int argc, char ** argv) {
         std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+        fclose(fp_read);
 
-        const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read);
-        if (ret != state_mem.size()) {
+        if (read != llama_set_state_data(ctx2, state_mem.data())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
             return 1;
         }
 
-        llama_set_state_data(ctx2, state_mem.data());
-
-        fclose(fp_read);
+        fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
     }
 
     // restore state (last tokens)
index 1d2eb569f01ff8b5c2acbd7f66640c58572f18d8..2754560884e5f662cc99ad024b0ef55dafe69dee 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -9379,12 +9379,8 @@ struct llama_context * llama_new_context_with_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
-        // resized during inference
-        if (params.logits_all) {
-            ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab);
-        } else {
-            ctx->logits.reserve(hparams.n_vocab);
-        }
+        // resized during inference, reserve maximum
+        ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
         if (params.embedding){
             ctx->embedding.resize(hparams.n_embd);
@@ -9731,8 +9727,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
     const size_t s_rng_size        = sizeof(size_t);
     const size_t s_rng             = LLAMA_MAX_RNG_STATE;
-    const size_t s_logits_capacity = sizeof(size_t);
     const size_t s_logits_size     = sizeof(size_t);
+    // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
     const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
@@ -9743,7 +9739,6 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_total = (
         + s_rng_size
         + s_rng
-        + s_logits_capacity
         + s_logits_size
         + s_logits
         + s_embedding_size
@@ -9812,37 +9807,27 @@ struct llama_data_file_context : llama_data_context {
 static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
-        std::stringstream rng_ss;
+        std::ostringstream rng_ss;
         rng_ss << ctx->rng;
 
-        const size_t rng_size = rng_ss.str().size();
-        char rng_buf[LLAMA_MAX_RNG_STATE];
+        const std::string & rng_str = rng_ss.str();
+        const size_t        rng_size = rng_str.size();
 
-        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
-        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
 
-        data_ctx->write(&rng_size,   sizeof(rng_size));
-        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
+        data_ctx->write(&rng_size,      sizeof(rng_size));
+        data_ctx->write(rng_str.data(), rng_size);
     }
 
     // copy logits
     {
-        const size_t logits_cap  = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();
 
-        data_ctx->write(&logits_cap,  sizeof(logits_cap));
         data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
             data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
-
-        // If there is a gap between the size and the capacity, write padding
-        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
-        if (padding_size > 0) {
-            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
-            data_ctx->write(padding.data(), padding_size);
-        }
     }
 
     // copy embeddings
@@ -9925,13 +9910,13 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     // set rng
     {
         size_t rng_size;
-        char   rng_buf[LLAMA_MAX_RNG_STATE];
+        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
 
-        memcpy(&rng_size,   inp, sizeof(rng_size));    inp += sizeof(rng_size);
-        memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
+        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
 
-        std::stringstream rng_ss;
-        rng_ss.str(std::string(&rng_buf[0], rng_size));
+        std::string rng_str((char *)inp, rng_size); inp += rng_size;
+
+        std::istringstream rng_ss(rng_str);
         rng_ss >> ctx->rng;
 
         GGML_ASSERT(!rng_ss.fail());
@@ -9939,20 +9924,18 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
     // set logits
     {
-        size_t logits_cap;
         size_t logits_size;
 
-        memcpy(&logits_cap,  inp, sizeof(logits_cap));  inp += sizeof(logits_cap);
         memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
 
-        GGML_ASSERT(ctx->logits.capacity() == logits_cap);
+        GGML_ASSERT(ctx->logits.capacity() >= logits_size);
 
         if (logits_size) {
             ctx->logits.resize(logits_size);
+
             memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
+            inp += logits_size * sizeof(float);
         }
-
-        inp += logits_cap * sizeof(float);
     }
 
     // set embeddings
diff --git a/llama.h b/llama.h
index 689e12d7ce0921504f630bccc3ae197e961deb11..01d6fafaa4b0d2fecc4a60d673ddbde68f738ad5 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -43,7 +43,7 @@
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 3
+#define LLAMA_SESSION_VERSION 4
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.