]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : free ggml context in set / copy state data (close #1425)
authorGeorgi Gerganov <redacted>
Sat, 13 May 2023 06:08:52 +0000 (09:08 +0300)
committerGeorgi Gerganov <redacted>
Sat, 13 May 2023 06:08:52 +0000 (09:08 +0300)
llama.cpp
llama.h

index 0a47faa9d738d6db3a1426e650086d66ceb5e4fb..f52671b67c63655dc3c49f27df541b368bd2c5ea 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2450,8 +2450,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
 }
 
 // Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
-    uint8_t * out = dest;
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    uint8_t * out = dst;
 
     // copy rng
     {
@@ -2511,7 +2511,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
+
             char buffer[4096];
+
             ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
             ggml_cgraph gf{};
             gf.n_threads = 1;
@@ -2535,10 +2537,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
             ggml_graph_compute(cpy_ctx, &gf);
+
+            ggml_free(cpy_ctx);
         }
     }
 
-    const size_t written  = out - dest;
+    const size_t written  = out - dst;
     const size_t max_size = llama_get_state_size(ctx);
 
     LLAMA_ASSERT(written <= max_size);
@@ -2548,15 +2552,15 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
 // Sets the state reading from the specified source address
 size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    const uint8_t * in = src;
+    const uint8_t * inp = src;
 
     // set rng
     {
         size_t rng_size;
         char   rng_buf[LLAMA_MAX_RNG_STATE];
 
-        memcpy(&rng_size,   in, sizeof(rng_size));    in += sizeof(rng_size);
-        memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
+        memcpy(&rng_size,   inp, sizeof(rng_size));    inp += sizeof(rng_size);
+        memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
 
         std::stringstream rng_ss;
         rng_ss.str(std::string(&rng_buf[0], rng_size));
@@ -2570,30 +2574,30 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         size_t logits_cap;
         size_t logits_size;
 
-        memcpy(&logits_cap,  in, sizeof(logits_cap));  in += sizeof(logits_cap);
-        memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+        memcpy(&logits_cap,  inp, sizeof(logits_cap));  inp += sizeof(logits_cap);
+        memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
 
         LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
 
         if (logits_size) {
             ctx->logits.resize(logits_size);
-            memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
         }
 
-        in += logits_cap * sizeof(float);
+        inp += logits_cap * sizeof(float);
     }
 
     // set embeddings
     {
         size_t embedding_size;
 
-        memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+        memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);
 
         LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
 
         if (embedding_size) {
-            memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
-            in += embedding_size * sizeof(float);
+            memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
+            inp += embedding_size * sizeof(float);
         }
     }
 
@@ -2608,25 +2612,27 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         size_t kv_size;
         int kv_ntok;
 
-        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
-        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
+        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
 
         if (kv_size) {
             LLAMA_ASSERT(kv_self.buf.size == kv_size);
 
             const size_t elt_size = ggml_element_size(kv_self.k);
+
             char buffer[4096];
+
             ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
             ggml_cgraph gf{};
             gf.n_threads = 1;
 
             ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kin3d->data = (void *) in;
-            in += ggml_nbytes(kin3d);
+            kin3d->data = (void *) inp;
+            inp += ggml_nbytes(kin3d);
 
             ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vin3d->data = (void *) in;
-            in += ggml_nbytes(vin3d);
+            vin3d->data = (void *) inp;
+            inp += ggml_nbytes(vin3d);
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
                 n_embd, kv_ntok, n_layer,
@@ -2639,12 +2645,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
             ggml_graph_compute(cpy_ctx, &gf);
+
+            ggml_free(cpy_ctx);
         }
 
         ctx->model.kv_self.n = kv_ntok;
     }
 
-    const size_t nread    = in - src;
+    const size_t nread    = inp - src;
     const size_t max_size = llama_get_state_size(ctx);
 
     LLAMA_ASSERT(nread <= max_size);
diff --git a/llama.h b/llama.h
index 1a65cd58923896b0807a6ec5379787157281d803..ca05645b974dedd6f8125cdeebdc521103253176 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -134,7 +134,7 @@ extern "C" {
     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
-    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
+    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
 
     // Set the state reading from the specified address
     // Returns the number of bytes read