]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor get / set state + remove redundant kv cache API (#1143)
authorGeorgi Gerganov <redacted>
Mon, 24 Apr 2023 04:40:02 +0000 (07:40 +0300)
committerGitHub <redacted>
Mon, 24 Apr 2023 04:40:02 +0000 (07:40 +0300)
llama.cpp
llama.h

index 8c1d65778be8bb844588cd4e54761dc4bd121cd3..bc0ef1281e37955b98aa643c7ce6f727f24f1270 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
     }
 }
 
-// Returns the KV cache that will contain the context for the
-// ongoing prediction with the model.
-const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
-    return ctx->model.kv_self.buf.addr;
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    return ctx->model.kv_self.n;
 }
 
-// Returns the size of the KV cache
-size_t llama_get_kv_cache_size(struct llama_context * ctx) {
-    return ctx->model.kv_self.buf.size;
+#define LLAMA_MAX_RNG_STATE 64*1024
+
+// Returns the size of the state
+size_t llama_get_state_size(struct llama_context * ctx) {
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size        = sizeof(size_t);
+    const size_t s_rng             = LLAMA_MAX_RNG_STATE;
+    const size_t s_logits_capacity = sizeof(size_t);
+    const size_t s_logits_size     = sizeof(size_t);
+    const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
+    const size_t s_embedding_size  = sizeof(size_t);
+    const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
+    const size_t s_kv_size         = sizeof(size_t);
+    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv              = ctx->model.kv_self.buf.size;
+
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_logits_capacity
+        + s_logits_size
+        + s_logits
+        + s_embedding_size
+        + s_embedding
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+
+    return s_total;
 }
 
-int llama_get_kv_cache_token_count(struct llama_context * ctx) {
-    return ctx->model.kv_self.n;
+// Copies the state to the specified destination address
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
+    uint8_t * out = dest;
+
+    // copy rng
+    {
+        std::stringstream rng_ss;
+        rng_ss << ctx->rng;
+
+        const size_t rng_size = rng_ss.str().size();
+        char rng_buf[LLAMA_MAX_RNG_STATE];
+
+        memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
+        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
+
+        memcpy(out, &rng_size,   sizeof(rng_size));    out += sizeof(rng_size);
+        memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
+    }
+
+    // copy logits
+    {
+        const size_t logits_cap  = ctx->logits.capacity();
+        const size_t logits_size = ctx->logits.size();
+
+        memcpy(out, &logits_cap,  sizeof(logits_cap));  out += sizeof(logits_cap);
+        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+
+        if (logits_size) {
+            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+        }
+
+        out += logits_cap * sizeof(float);
+    }
+
+    // copy embeddings
+    {
+        const size_t embedding_size = ctx->embedding.size();
+
+        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
+
+        if (embedding_size) {
+            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
+            out += embedding_size * sizeof(float);
+        }
+    }
+
+    // copy kv cache
+    {
+        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
+
+        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
+        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+
+        if (kv_size) {
+            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+        }
+    }
+
+    const size_t written  = out - dest;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(written == expected);
+
+    return written;
 }
 
-// Sets the KV cache containing the current context for the model
-void llama_set_kv_cache(
-        struct llama_context * ctx,
-               const uint8_t * kv_cache,
-                      size_t   n_size,
-                         int   n_token_count) {
-    // Make sure we have the same kv cache setup
-    LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
-    void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-    void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
-    memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
-    ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-    ctx->model.kv_self.v->data = v_data;
-    ctx->model.kv_self.n = n_token_count;
+// Sets the state reading from the specified source address
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    const uint8_t * in = src;
+
+    // set rng
+    {
+        size_t rng_size;
+        char   rng_buf[LLAMA_MAX_RNG_STATE];
+
+        memcpy(&rng_size,   in, sizeof(rng_size));    in += sizeof(rng_size);
+        memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
+
+        std::stringstream rng_ss;
+        rng_ss.str(std::string(&rng_buf[0], rng_size));
+        rng_ss >> ctx->rng;
+
+        LLAMA_ASSERT(rng_ss.fail() == false);
+    }
+
+    // set logits
+    {
+        size_t logits_cap;
+        size_t logits_size;
+
+        memcpy(&logits_cap,  in, sizeof(logits_cap));  in += sizeof(logits_cap);
+        memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+
+        LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
+
+        if (logits_size) {
+            ctx->logits.resize(logits_size);
+            memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+        }
+
+        in += logits_cap * sizeof(float);
+    }
+
+    // set embeddings
+    {
+        size_t embedding_size;
+
+        memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+
+        LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
+
+        if (embedding_size) {
+            memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
+            in += embedding_size * sizeof(float);
+        }
+    }
+
+    // set kv cache
+    {
+        size_t kv_size;
+        int kv_ntok;
+
+        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
+        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+
+        if (kv_size) {
+            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+
+            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+
+            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+
+            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+            ctx->model.kv_self.v->data = v_data;
+
+        }
+
+        ctx->model.kv_self.n = kv_ntok;
+    }
+
+    const size_t nread    = in - src;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(nread == expected);
+
+    return nread;
 }
 
 int llama_eval(
@@ -2256,120 +2412,3 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
     return ctx->model.tensors_by_name;
 }
 
-// Returns the size of the state
-size_t llama_get_state_size(struct llama_context * ctx) {
-    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
-    // for reference, std::mt19937(1337) serializes to 6701 bytes.
-    const size_t s_rng_size = sizeof(size_t);
-    const size_t s_rng = 64*1024;
-    const size_t s_logits_capacity = sizeof(size_t);
-    const size_t s_logits_size = sizeof(size_t);
-    const size_t s_logits = ctx->logits.capacity() * sizeof(float);
-    const size_t s_embedding_size = sizeof(size_t);
-    const size_t s_embedding = ctx->embedding.size() * sizeof(float);
-    const size_t s_kv_size = sizeof(size_t);
-    const size_t s_kv_ntok = sizeof(int);
-    const size_t s_kv = llama_get_kv_cache_size(ctx);
-    const size_t s_total = (
-        + s_rng_size
-        + s_rng
-        + s_logits_capacity
-        + s_logits_size
-        + s_logits
-        + s_embedding_size
-        + s_embedding
-        + s_kv_size
-        + s_kv_ntok
-        + s_kv
-    );
-    return s_total;
-}
-
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
-    std::stringstream rng_ss;
-    rng_ss << ctx->rng;
-    const size_t rng_size = rng_ss.str().size();
-    char rng_buf[64*1024];
-    memset(&rng_buf[0], 0, 64*1024);
-    memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
-    const size_t logits_capacity = ctx->logits.capacity();
-    const size_t logits_size = ctx->logits.size();
-    const size_t embedding_size = ctx->embedding.size();
-    const size_t kv_size = llama_get_kv_cache_size(ctx);
-    const int kv_ntok = llama_get_kv_cache_token_count(ctx);
-
-    uint8_t * out = dest;
-    memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
-    memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
-    if (logits_size) {
-        memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
-    }
-    out += logits_capacity * sizeof(float);
-    memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
-    if (embedding_size) {
-        memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
-    }
-    memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
-    memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
-    if (kv_size) {
-        memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
-    }
-    const size_t written = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
-    LLAMA_ASSERT(written == expected);
-    return written;
-}
-
-// Sets the state reading from the specified source address
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    size_t rng_size;
-    char rng_buf[64*1024];
-    std::stringstream rng_ss;
-
-    const uint8_t * in = src;
-    memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
-    rng_ss.str(std::string(&rng_buf[0], rng_size));
-    rng_ss >> ctx->rng;
-    LLAMA_ASSERT(rng_ss.fail() == false);
-
-    size_t logits_capacity;
-    size_t logits_size;
-    size_t embedding_size;
-    size_t kv_size;
-    int kv_ntok;
-
-    memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
-    LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
-    if (logits_size) {
-        ctx->logits.resize(logits_size);
-        memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
-    }
-    in += logits_capacity * sizeof(float);
-    memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
-    LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
-    if (embedding_size) {
-        memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
-        in += embedding_size * sizeof(float);
-    }
-    memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
-    memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
-    if (kv_size) {
-        LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
-        void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-        void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
-        memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
-        ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-        ctx->model.kv_self.v->data = v_data;
-        in += kv_size;
-    }
-    ctx->model.kv_self.n = kv_ntok;
-    const size_t nread = in - src;
-    const size_t expected = llama_get_state_size(ctx);
-    LLAMA_ASSERT(nread == expected);
-    return nread;
-}
diff --git a/llama.h b/llama.h
index f68a0cb403b21c315f8fe816454fe1b00d935b79..e9e3abea597ebe1ab6d5fe731e8eb9bbcd81fa4a 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -112,23 +112,9 @@ extern "C" {
                       const char * path_base_model,
                              int   n_threads);
 
-    // Returns the KV cache that will contain the context for the
-    // ongoing prediction with the model.
-    LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
-
-    // Returns the size of the KV cache
-    LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
-
     // Returns the number of tokens in the KV cache
     LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
 
-    // Sets the KV cache containing the current context for the model
-    LLAMA_API void llama_set_kv_cache(
-            struct llama_context * ctx,
-                   const uint8_t * kv_cache,
-                          size_t   n_size,
-                             int   n_token_count);
-
     // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
     LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);