}
}
-// Returns the KV cache that will contain the context for the
-// ongoing prediction with the model.
-const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
- return ctx->model.kv_self.buf.addr;
+// Returns the KV cache token count, i.e. the current value of kv_self.n
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    const int n_tokens = ctx->model.kv_self.n;
+    return n_tokens;
}
-// Returns the size of the KV cache
-size_t llama_get_kv_cache_size(struct llama_context * ctx) {
- return ctx->model.kv_self.buf.size;
+// Fixed number of bytes reserved for the serialized RNG inside a state blob.
+// Parenthesized so the macro expands safely inside larger expressions
+// (e.g. `x % LLAMA_MAX_RNG_STATE`).
+#define LLAMA_MAX_RNG_STATE (64*1024)
+
+// Returns the size of the state
+//
+// The blob is laid out, in order, as:
+//   size_t  rng_size         | LLAMA_MAX_RNG_STATE bytes of rng text
+//   size_t  logits capacity  | size_t logits size | capacity * sizeof(float) bytes
+//   size_t  embedding size   | size * sizeof(float) bytes
+//   size_t  kv buffer size   | int kv token count | kv buffer bytes
+// llama_copy_state_data() and llama_set_state_data() assert that they
+// produce / consume exactly this many bytes.
+size_t llama_get_state_size(struct llama_context * ctx) {
+    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
+    // for reference, std::mt19937(1337) serializes to 6701 bytes.
+    const size_t s_rng_size        = sizeof(size_t);
+    const size_t s_rng             = LLAMA_MAX_RNG_STATE;
+    const size_t s_logits_capacity = sizeof(size_t);
+    const size_t s_logits_size     = sizeof(size_t);
+    // the logits region is sized by capacity, not by the current size
+    const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
+    const size_t s_embedding_size  = sizeof(size_t);
+    const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
+    const size_t s_kv_size         = sizeof(size_t);
+    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv              = ctx->model.kv_self.buf.size;
+
+    const size_t s_total = (
+        + s_rng_size
+        + s_rng
+        + s_logits_capacity
+        + s_logits_size
+        + s_logits
+        + s_embedding_size
+        + s_embedding
+        + s_kv_size
+        + s_kv_ntok
+        + s_kv
+    );
+
+    return s_total;
}
-int llama_get_kv_cache_token_count(struct llama_context * ctx) {
- return ctx->model.kv_self.n;
+// Copies the state to the specified destination address
+//
+// dest must have room for llama_get_state_size(ctx) bytes: the function
+// asserts that exactly that many bytes were produced, and returns the count.
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
+    uint8_t * out = dest;
+
+    // copy rng
+    {
+        std::stringstream rng_ss;
+        rng_ss << ctx->rng;
+
+        // serialize once; every str() call returns a fresh copy of the buffer
+        const std::string rng_str  = rng_ss.str();
+        const size_t      rng_size = rng_str.size();
+
+        // the rng slot has a fixed width, so a longer serialization cannot be
+        // stored; fail loudly instead of writing past the slot
+        LLAMA_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+
+        memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
+
+        // write the rng text straight into its slot, zero-padding the remainder
+        memset(out, 0, LLAMA_MAX_RNG_STATE);
+        memcpy(out, rng_str.data(), rng_size);
+        out += LLAMA_MAX_RNG_STATE;
+    }
+
+    // copy logits
+    {
+        const size_t logits_cap  = ctx->logits.capacity();
+        const size_t logits_size = ctx->logits.size();
+
+        memcpy(out, &logits_cap,  sizeof(logits_cap));  out += sizeof(logits_cap);
+        memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);
+
+        if (logits_size) {
+            memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
+        }
+
+        // the logits region always spans the full capacity, even when only
+        // logits_size entries were written above
+        out += logits_cap * sizeof(float);
+    }
+
+    // copy embeddings
+    {
+        const size_t embedding_size = ctx->embedding.size();
+
+        memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);
+
+        if (embedding_size) {
+            memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
+            out += embedding_size * sizeof(float);
+        }
+    }
+
+    // copy kv cache
+    {
+        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
+
+        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
+        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
+
+        if (kv_size) {
+            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+        }
+    }
+
+    // sanity check: the number of bytes written must match the advertised size
+    const size_t written  = out - dest;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(written == expected);
+
+    return written;
}
-// Sets the KV cache containing the current context for the model
-void llama_set_kv_cache(
- struct llama_context * ctx,
- const uint8_t * kv_cache,
- size_t n_size,
- int n_token_count) {
- // Make sure we have the same kv cache setup
- LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
- void * k_data = ctx->model.kv_self.k->data; // remember data pointers
- void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
- memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
- ctx->model.kv_self.k->data = k_data; // restore correct data pointers
- ctx->model.kv_self.v->data = v_data;
- ctx->model.kv_self.n = n_token_count;
+// Sets the state reading from the specified source address
+//
+// src is expected to hold a blob in the layout produced by
+// llama_copy_state_data(). The stored sizes are checked against the
+// context with LLAMA_ASSERT, and the number of bytes consumed is asserted
+// to equal llama_get_state_size(ctx) before being returned.
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    const uint8_t * in = src;
+
+    // set rng
+    {
+        size_t rng_size;
+
+        memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
+
+        // the stored length comes from the blob; it must fit the fixed-width
+        // rng slot or the string construction below would read out of bounds
+        LLAMA_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+
+        // parse the rng text directly from the blob, no intermediate buffer
+        std::stringstream rng_ss;
+        rng_ss.str(std::string(reinterpret_cast<const char *>(in), rng_size));
+        in += LLAMA_MAX_RNG_STATE;
+
+        rng_ss >> ctx->rng;
+
+        LLAMA_ASSERT(rng_ss.fail() == false);
+    }
+
+    // set logits
+    {
+        size_t logits_cap;
+        size_t logits_size;
+
+        memcpy(&logits_cap,  in, sizeof(logits_cap));  in += sizeof(logits_cap);
+        memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+
+        LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
+
+        // a size above the capacity would make resize() reallocate and the
+        // memcpy read beyond the logits region of the blob
+        LLAMA_ASSERT(logits_size <= logits_cap);
+
+        if (logits_size) {
+            ctx->logits.resize(logits_size);
+            memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+        }
+
+        // the logits region always spans the full capacity
+        in += logits_cap * sizeof(float);
+    }
+
+    // set embeddings
+    {
+        size_t embedding_size;
+
+        memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+
+        // NOTE(review): this compares against capacity() while
+        // llama_get_state_size() uses size(); presumably the two are equal
+        // for the embedding vector - confirm against where it is allocated
+        LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
+
+        if (embedding_size) {
+            memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
+            in += embedding_size * sizeof(float);
+        }
+    }
+
+    // set kv cache
+    {
+        size_t kv_size;
+        int    kv_ntok;
+
+        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
+        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+
+        if (kv_size) {
+            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+
+            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
+            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+
+            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+
+            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
+            ctx->model.kv_self.v->data = v_data;
+        }
+
+        ctx->model.kv_self.n = kv_ntok;
+    }
+
+    // sanity check: the number of bytes consumed must match the advertised size
+    const size_t nread    = in - src;
+    const size_t expected = llama_get_state_size(ctx);
+
+    LLAMA_ASSERT(nread == expected);
+
+    return nread;
}
int llama_eval(
return ctx->model.tensors_by_name;
}
-// Returns the size of the state
-size_t llama_get_state_size(struct llama_context * ctx) {
- // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
- // for reference, std::mt19937(1337) serializes to 6701 bytes.
- const size_t s_rng_size = sizeof(size_t);
- const size_t s_rng = 64*1024;
- const size_t s_logits_capacity = sizeof(size_t);
- const size_t s_logits_size = sizeof(size_t);
- const size_t s_logits = ctx->logits.capacity() * sizeof(float);
- const size_t s_embedding_size = sizeof(size_t);
- const size_t s_embedding = ctx->embedding.size() * sizeof(float);
- const size_t s_kv_size = sizeof(size_t);
- const size_t s_kv_ntok = sizeof(int);
- const size_t s_kv = llama_get_kv_cache_size(ctx);
- const size_t s_total = (
- + s_rng_size
- + s_rng
- + s_logits_capacity
- + s_logits_size
- + s_logits
- + s_embedding_size
- + s_embedding
- + s_kv_size
- + s_kv_ntok
- + s_kv
- );
- return s_total;
-}
-
-// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
- std::stringstream rng_ss;
- rng_ss << ctx->rng;
- const size_t rng_size = rng_ss.str().size();
- char rng_buf[64*1024];
- memset(&rng_buf[0], 0, 64*1024);
- memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
- const size_t logits_capacity = ctx->logits.capacity();
- const size_t logits_size = ctx->logits.size();
- const size_t embedding_size = ctx->embedding.size();
- const size_t kv_size = llama_get_kv_cache_size(ctx);
- const int kv_ntok = llama_get_kv_cache_token_count(ctx);
-
- uint8_t * out = dest;
- memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
- memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
- memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
- memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
- if (logits_size) {
- memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
- }
- out += logits_capacity * sizeof(float);
- memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
- if (embedding_size) {
- memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
- }
- memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
- memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
- if (kv_size) {
- memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
- }
- const size_t written = out - dest;
- const size_t expected = llama_get_state_size(ctx);
- LLAMA_ASSERT(written == expected);
- return written;
-}
-
-// Sets the state reading from the specified source address
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
- size_t rng_size;
- char rng_buf[64*1024];
- std::stringstream rng_ss;
-
- const uint8_t * in = src;
- memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
- memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
- rng_ss.str(std::string(&rng_buf[0], rng_size));
- rng_ss >> ctx->rng;
- LLAMA_ASSERT(rng_ss.fail() == false);
-
- size_t logits_capacity;
- size_t logits_size;
- size_t embedding_size;
- size_t kv_size;
- int kv_ntok;
-
- memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
- memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
- LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
- if (logits_size) {
- ctx->logits.resize(logits_size);
- memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
- }
- in += logits_capacity * sizeof(float);
- memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
- LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
- if (embedding_size) {
- memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
- in += embedding_size * sizeof(float);
- }
- memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
- memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
- if (kv_size) {
- LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
- void * k_data = ctx->model.kv_self.k->data; // remember data pointers
- void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
- memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
- ctx->model.kv_self.k->data = k_data; // restore correct data pointers
- ctx->model.kv_self.v->data = v_data;
- in += kv_size;
- }
- ctx->model.kv_self.n = kv_ntok;
- const size_t nread = in - src;
- const size_t expected = llama_get_state_size(ctx);
- LLAMA_ASSERT(nread == expected);
- return nread;
-}