]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Stream save llama context data to file instead of allocating entire buffer upfront...
authorl3utterfly <redacted>
Fri, 4 Aug 2023 11:29:52 +0000 (19:29 +0800)
committerGitHub <redacted>
Fri, 4 Aug 2023 11:29:52 +0000 (13:29 +0200)
* added stream saving context data to file to avoid allocating unnecessary amounts of memory

* generalised copying state data to file or buffer

* added comments explaining how copy_state_data works

* fixed trailing whitespaces

* fixed save load state example

* updated save load state to use public function in llama.cpp

* - restored breakage of the llama_copy_state_data API
- moved new logic for copying llama state data to internal function

* fixed function declaration order

* restored save load state example

* fixed whitepace

* removed unused llama-util.h include

* Apply suggestions from code review

Co-authored-by: slaren <redacted>
* Apply code review suggestions

Co-authored-by: slaren <redacted>
---------

Co-authored-by: slaren <redacted>
llama-util.h
llama.cpp

index 042ebe43c48e1346d747bce23ec7db064c9d45a2..3fc03ce28273ed4ed2a09647160152c68e31ea8b 100644 (file)
@@ -149,6 +149,46 @@ struct llama_file {
     }
 };
 
+// llama_context_data
+struct llama_data_context {
+    virtual void write(const void * src, size_t size) = 0;
+    virtual size_t get_size_written() = 0;
+    virtual ~llama_data_context() = default;
+};
+
+struct llama_data_buffer_context : llama_data_context {
+    uint8_t* ptr;
+    size_t size_written = 0;
+
+    llama_data_buffer_context(uint8_t * p) : ptr(p) {}
+
+    void write(const void * src, size_t size) override {
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_file_context : llama_data_context {
+    llama_file* file;
+    size_t size_written = 0;
+
+    llama_data_file_context(llama_file * f) : file(f) {}
+
+    void write(const void * src, size_t size) override {
+        file->write_raw(src, size);
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
     LPSTR buf;
index d427054dd8d6ae585490fdc63ddb5cd2299c68fa..839739870eb3e97a72fd56e68f4377a3639ac4b9 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -3743,10 +3743,20 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     return s_total;
 }
 
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    uint8_t * out = dst;
-
+/** copy state data into either a buffer or file depending on the passed in context
+ *
+ * file context:
+ * llama_file file("/path", "wb");
+ * llama_data_file_context data_ctx(&file);
+ * llama_copy_state_data(ctx, &data_ctx);
+ *
+ * buffer context:
+ * std::vector<uint8_t> buf(max_size, 0);
+ * llama_data_buffer_context data_ctx(&buf.data());
+ * llama_copy_state_data(ctx, &data_ctx);
+ *
+*/
+void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
         std::stringstream rng_ss;
@@ -3758,8 +3768,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
         memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
 
-        memcpy(out, &rng_size,   sizeof(rng_size));    out += sizeof(rng_size);
-        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+        data_ctx->write(&rng_size,   sizeof(rng_size));
+        data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE);
     }
 
     // copy logits
@@ -3767,25 +3777,29 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         const size_t logits_cap  = ctx->logits.capacity();
         const size_t logits_size = ctx->logits.size();
 
-        memcpy(out, &logits_cap,  sizeof(logits_cap));  out += sizeof(logits_cap);
-        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+        data_ctx->write(&logits_cap,  sizeof(logits_cap));
+        data_ctx->write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
-            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+            data_ctx->write(ctx->logits.data(), logits_size * sizeof(float));
         }
 
-        out += logits_cap * sizeof(float);
+        // If there is a gap between the size and the capacity, write padding
+        size_t padding_size = (logits_cap - logits_size) * sizeof(float);
+        if (padding_size > 0) {
+            std::vector<uint8_t> padding(padding_size, 0); // Create a buffer filled with zeros
+            data_ctx->write(padding.data(), padding_size);
+        }
     }
 
     // copy embeddings
     {
         const size_t embedding_size = ctx->embedding.size();
 
-        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
+        data_ctx->write(&embedding_size, sizeof(embedding_size));
 
         if (embedding_size) {
-            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
-            out += embedding_size * sizeof(float);
+            data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
         }
     }
 
@@ -3800,8 +3814,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
         const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
 
-        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
-        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+        data_ctx->write(&kv_size, sizeof(kv_size));
+        data_ctx->write(&kv_ntok, sizeof(kv_ntok));
 
         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -3810,12 +3824,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
             ggml_cgraph gf{};
 
             ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kout3d->data = out;
-            out += ggml_nbytes(kout3d);
+            std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
+            kout3d->data = kout3d_data.data();
 
             ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vout3d->data = out;
-            out += ggml_nbytes(vout3d);
+            std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
+            vout3d->data = vout3d_data.data();
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
                 n_embd, kv_ntok, n_layer,
@@ -3830,15 +3844,20 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
             ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
 
             ggml_free(cpy_ctx);
+
+            // our data is now in the kout3d_data and vout3d_data buffers
+            // write them to file
+            data_ctx->write(kout3d_data.data(), kout3d_data.size());
+            data_ctx->write(vout3d_data.data(), vout3d_data.size());
         }
     }
+}
 
-    const size_t written  = out - dst;
-    const size_t max_size = llama_get_state_size(ctx);
-
-    LLAMA_ASSERT(written <= max_size);
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    llama_data_buffer_context data_ctx(dst);
+    llama_copy_state_data_internal(ctx, &data_ctx);
 
-    return written;
+    return data_ctx.get_size_written();
 }
 
 // Sets the state reading from the specified source address
@@ -4023,15 +4042,9 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     file.write_u32((uint32_t) n_token_count);
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
-    // save the context state
-    {
-        const size_t n_state_size_max = llama_get_state_size(ctx);
-
-        std::vector<uint8_t> state_data(n_state_size_max);
-        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
-
-        file.write_raw(state_data.data(), n_state_size_cur);
-    }
+    // save the context state using stream saving
+    llama_data_file_context data_ctx(&file);
+    llama_copy_state_data_internal(ctx, &data_ctx);
 
     return true;
 }