]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : only copy used KV cache in get / set state (#890)
authorLuis Herrera <redacted>
Mon, 8 May 2023 17:59:21 +0000 (12:59 -0500)
committerGitHub <redacted>
Mon, 8 May 2023 17:59:21 +0000 (20:59 +0300)
---------

Co-authored-by: ejones <redacted>
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h

index d92b1d9d33344f4ed35011b5b2816721b6ff09fd..dfc68ed60604dfa424e014cecf5204af3d3fedc1 100644 (file)
@@ -1270,6 +1270,9 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;
@@ -2386,7 +2389,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }
 
-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2465,21 +2468,51 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
     // copy kv cache
     {
-        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
 
         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
 
         if (kv_size) {
-            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kout3d->data = out;
+            out += ggml_nbytes(kout3d);
+
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vout3d->data = out;
+            out += ggml_nbytes(vout3d);
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+            ggml_graph_compute(cpy_ctx, &gf);
         }
     }
 
     const size_t written  = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(written == expected);
+    LLAMA_ASSERT(written <= max_size);
 
     return written;
 }
@@ -2537,6 +2570,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
         size_t kv_size;
         int kv_ntok;
 
@@ -2544,15 +2583,33 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
 
         if (kv_size) {
-            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);
+
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kin3d->data = (void *) in;
+            in += ggml_nbytes(kin3d);
+
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vin3d->data = (void *) in;
+            in += ggml_nbytes(vin3d);
 
-            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
-            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
-            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-            ctx->model.kv_self.v->data = v_data;
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+            ggml_graph_compute(cpy_ctx, &gf);
 
         }
 
@@ -2560,9 +2617,9 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     }
 
     const size_t nread    = in - src;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(nread == expected);
+    LLAMA_ASSERT(nread <= max_size);
 
     return nread;
 }
@@ -2733,14 +2790,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_exp = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        if (n_state_size_cur != n_state_size_exp) {
-            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+       if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }
 
-        std::vector<uint8_t> state_data(n_state_size_cur);
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);
 
         llama_set_state_data(ctx, state_data.data());
@@ -2763,12 +2820,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 
     // save the context state
     {
-        const size_t n_state_size = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        std::vector<uint8_t> state_data(n_state_size);
-        llama_copy_state_data(ctx, state_data.data());
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
 
-        file.write_raw(state_data.data(), n_state_size);
+        file.write_raw(state_data.data(), n_state_size_cur);
     }
 
     return true;
index 40d152edd26fda470a2e015d92ec91ffea65056a..2f12090b56953451f0c8a75652dce426529275bc 100644 (file)
@@ -23,7 +23,7 @@
 #define LLAMA_FILE_MAGIC             'ggjt'
 #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
 #define LLAMA_SESSION_MAGIC          'ggsn'
-#define LLAMA_SESSION_VERSION        0
+#define LLAMA_SESSION_VERSION        1
 
 #ifdef __cplusplus
 extern "C" {
@@ -127,7 +127,8 @@ extern "C" {
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
 
-    // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+    // Returns the maximum size in bytes of the state (rng, logits, embedding
+    // and kv_cache) - will often be smaller after compacting tokens
     LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
 
     // Copies the state to the specified destination address.