// TODO: find a nicer way to add other recurrent model architectures
cache.recurrent = model.arch == LLM_ARCH_MAMBA;
- cache.v_trans = !cparams.flash_attn;
+ cache.v_trans = !cache.recurrent && !cparams.flash_attn;
cache.head = 0;
cache.size = kv_size;
}
// deprecated
-size_t llama_get_state_size(const struct llama_context * ctx) {
+size_t llama_get_state_size(struct llama_context * ctx) {
return llama_state_get_size(ctx);
}
// deprecated
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
- return llama_state_get_data(ctx, dst);
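+ // note: a size of -1 wraps around to SIZE_MAX, so the deprecated wrappers keep their old unbounded-buffer behavior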
+ return llama_state_get_data(ctx, dst, -1);
}
// deprecated
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
- return llama_state_set_data(ctx, src);
+ return llama_state_set_data(ctx, src, -1);
}
// deprecated
return llama_state_save_file(ctx, path_session, tokens, n_token_count);
}
-// Returns the *maximum* size of the state
-size_t llama_state_get_size(const struct llama_context * ctx) {
- const auto & cparams = ctx->cparams;
- const auto & hparams = ctx->model.hparams;
-
- // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
- // for reference, std::mt19937(1337) serializes to 6701 bytes.
- const size_t s_rng_size = sizeof(size_t);
- const size_t s_rng = LLAMA_MAX_RNG_STATE;
- const size_t s_n_outputs = sizeof(size_t);
- // assume worst case for outputs although only currently set ones are serialized
- const size_t s_output_pos = ctx->cparams.n_batch * sizeof(int32_t);
- const size_t s_logits_size = sizeof(size_t);
- const size_t s_logits = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
- const size_t s_embedding_size = sizeof(size_t);
- const size_t s_embedding = ctx->embd_size ? cparams.n_batch * hparams.n_embd * sizeof(float) : 0;
- const size_t s_kv_buf_size = sizeof(size_t);
- const size_t s_kv_head = sizeof(uint32_t);
- const size_t s_kv_size = sizeof(uint32_t);
- const size_t s_kv_used = sizeof(uint32_t);
- const size_t s_v_trans = sizeof(uint32_t);
- const size_t s_kv = ctx->kv_self.total_size();
- const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
- const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
-
- const size_t s_total = (
- + s_rng_size
- + s_rng
- + s_n_outputs
- + s_output_pos
- + s_logits_size
- + s_logits
- + s_embedding_size
- + s_embedding
- + s_kv_buf_size
- + s_kv_head
- + s_kv_size
- + s_kv_used
- + s_v_trans
- + s_kv
- + s_kv_cells
- );
-
- // on session change it is very likely that the state size has changed - so we need to update this function
- static_assert(LLAMA_SESSION_VERSION == 7, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
-
- return s_total;
-}
-
-// llama_context_data
-struct llama_data_context {
+// TODO: replace all non-fatal assertions with returned errors or exceptions
+struct llama_data_write {
virtual void write(const void * src, size_t size) = 0;
virtual size_t get_size_written() = 0;
- virtual ~llama_data_context() = default;
-};
+ virtual ~llama_data_write() = default;
-struct llama_data_buffer_context : llama_data_context {
- uint8_t * ptr;
- size_t size_written = 0;
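+ // strings are serialized as a 32-bit length followed by the raw bytes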
+ void write_string(const std::string & str) {
+ uint32_t str_size = str.size();
- llama_data_buffer_context(uint8_t * p) : ptr(p) {}
-
- void write(const void * src, size_t size) override {
- memcpy(ptr, src, size);
- ptr += size;
- size_written += size;
+ write(&str_size, sizeof(str_size));
+ write(str.data(), str_size);
}
- size_t get_size_written() override {
- return size_written;
+ void write_model_info(const struct llama_context * ctx) {
+ std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+ write_string(arch_str);
+ // TODO: add more model-specific info which should prevent loading the session file if not identical
}
-};
-struct llama_data_file_context : llama_data_context {
- llama_file * file;
- size_t size_written = 0;
+ void write_rng(const std::mt19937 & rng) {
+ std::ostringstream rng_ss;
+ rng_ss << rng;
- llama_data_file_context(llama_file * f) : file(f) {}
+ const std::string & rng_str = rng_ss.str();
- void write(const void * src, size_t size) override {
- file->write_raw(src, size);
- size_written += size;
+ write_string(rng_str);
}
- size_t get_size_written() override {
- return size_written;
- }
-};
+ void write_output_ids(const struct llama_context * ctx) {
+ const uint32_t n_outputs = ctx->n_outputs;
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_file_context data_ctx(&file);
- * llama_state_get_data(ctx, &data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_buffer_context data_ctx(&buf.data());
- * llama_state_get_data(ctx, &data_ctx);
- *
-*/
-static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
- llama_synchronize(ctx);
+ std::vector<int32_t> output_pos;
- // copy rng
- {
- std::ostringstream rng_ss;
- rng_ss << ctx->sampling.rng;
+ const size_t n_batch = ctx->cparams.n_batch;
+ const auto & output_ids = ctx->output_ids;
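+ // output_ids maps a batch position to an output row; output_pos stores the inverse mapping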
+
+ GGML_ASSERT(n_outputs <= ctx->output_size);
+
+ output_pos.resize(n_outputs);
+
+ // build a more compact representation of the output ids
+ for (size_t i = 0; i < n_batch; ++i) {
+ // map an output id to a position in the batch
+ int32_t pos = output_ids[i];
+ if (pos >= 0) {
+ GGML_ASSERT((uint32_t) pos < n_outputs);
+ output_pos[pos] = i;
+ }
+ }
- const std::string & rng_str = rng_ss.str();
- const size_t rng_size = rng_str.size();
+ write(&n_outputs, sizeof(n_outputs));
- GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+ if (n_outputs) {
+ write(output_pos.data(), n_outputs * sizeof(int32_t));
+ }
+ }
- data_ctx->write(&rng_size, sizeof(rng_size));
- data_ctx->write(rng_str.data(), rng_size);
+ void write_logits(const struct llama_context * ctx) {
+ const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+
+ write(&logits_size, sizeof(logits_size));
+
+ if (logits_size) {
+ write(ctx->logits, logits_size * sizeof(float));
+ }
}
- // copy outputs
- {
- // Can't use ctx->n_outputs because it's not for the
- // entire last batch when n_ubatch is smaller than n_batch
- size_t n_outputs = 0;
+ void write_embeddings(const struct llama_context * ctx) {
+ const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
- // copy output ids
- {
- std::vector<int32_t> output_pos;
+ write(&embeddings_size, sizeof(embeddings_size));
+
+ if (embeddings_size) {
+ write(ctx->embd, embeddings_size * sizeof(float));
+ }
+ }
+
+ void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
- const size_t n_batch = ctx->cparams.n_batch;
- const auto & output_ids = ctx->output_ids;
+ for (const auto & range : cell_ranges) {
+ for (uint32_t i = range.first; i < range.second; ++i) {
+ const auto & cell = kv_self.cells[i];
+ const llama_pos pos = cell.pos;
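+ // when saving a single sequence, the seq ids are implied by the destination seq id at load time, so none are written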
+ const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
- output_pos.resize(ctx->output_size);
+ write(&pos, sizeof(pos));
+ write(&n_seq_id, sizeof(n_seq_id));
- // build a more compact representation of the output ids
- for (size_t i = 0; i < n_batch; ++i) {
- // map an output id to a position in the batch
- int32_t pos = output_ids[i];
- if (pos >= 0) {
- if ((size_t) pos >= n_outputs) {
- n_outputs = pos + 1;
+ if (n_seq_id) {
+ for (auto seq_id : cell.seq_id) {
+ write(&seq_id, sizeof(seq_id));
}
- GGML_ASSERT((size_t) pos < ctx->output_size);
- output_pos[pos] = i;
}
}
+ }
+ }
- data_ctx->write(&n_outputs, sizeof(n_outputs));
+ void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
+ const struct llama_kv_cache & kv_self = ctx->kv_self;
+ const struct llama_hparams & hparams = ctx->model.hparams;
- if (n_outputs) {
- data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
- }
- }
+ const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
+ const uint32_t n_layer = hparams.n_layer;
- // copy logits
- {
- const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
+ write(&v_trans, sizeof(v_trans));
+ write(&n_layer, sizeof(n_layer));
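+ // keys are always stored one row per cell; values use the same layout unless v_trans is set, in which case they are stored transposed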
- data_ctx->write(&logits_size, sizeof(logits_size));
+ std::vector<uint8_t> tmp_buf;
- if (logits_size) {
- data_ctx->write(ctx->logits, logits_size * sizeof(float));
- }
- }
+ // Iterate and write all the keys first; each row is a cell
+ // Get whole ranges at a time
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
- // copy embeddings
- {
- const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+ // Write key type
+ const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+ write(&k_type_i, sizeof(k_type_i));
- data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+ // Write row size of key
+ const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+ write(&k_size_row, sizeof(k_size_row));
- if (embeddings_size) {
- data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+ // Read each range of cells (k_size_row bytes per cell) into tmp_buf and write it out
+ for (const auto & range : cell_ranges) {
+ const size_t range_size = range.second - range.first;
+ tmp_buf.resize(range_size * k_size_row);
+ ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
+ write(tmp_buf.data(), tmp_buf.size());
}
}
- }
- // copy kv cache
- {
- const auto & kv_self = ctx->kv_self;
- const auto & hparams = ctx->model.hparams;
-
- const uint32_t n_layer = hparams.n_layer;
-
- // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
- const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
- const uint32_t kv_size = kv_self.size;
- const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
- const uint32_t kv_used = kv_self.used;
- const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
-
- data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
- data_ctx->write(&kv_head, sizeof(kv_head));
- data_ctx->write(&kv_size, sizeof(kv_size));
- data_ctx->write(&kv_used, sizeof(kv_used));
- data_ctx->write(&v_trans, sizeof(v_trans));
-
- if (kv_buf_size) {
- const size_t pre_kv_buf_size = data_ctx->get_size_written();
-
- std::vector<uint8_t> tmp_buf;
- for (int il = 0; il < (int) n_layer; ++il) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+ if (!kv_self.v_trans) {
+ for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
- const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+ // Write value type
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ write(&v_type_i, sizeof(v_type_i));
- tmp_buf.resize(k_size);
- ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
- data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ // Write row size of value
+ const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+ write(&v_size_row, sizeof(v_size_row));
- if (kv_self.recurrent || !kv_self.v_trans) {
- // v is contiguous for recurrent models
- // TODO: use other tensors for state models than k and v
- const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
-
- tmp_buf.resize(v_size);
- ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
- data_ctx->write(tmp_buf.data(), tmp_buf.size());
- continue;
+ // Read each range of cells (v_size_row bytes per cell) into tmp_buf and write it out
+ for (const auto & range : cell_ranges) {
+ const size_t range_size = range.second - range.first;
+ tmp_buf.resize(range_size * v_size_row);
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
+ write(tmp_buf.data(), tmp_buf.size());
}
+ }
+ } else {
+ // When v is transposed, we also need the element size, and the cell ranges are gathered from each row
+ const uint32_t kv_size = kv_self.size;
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
- // v is not contiguous, copy row by row
- const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
- const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
+ // Write value type
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ write(&v_type_i, sizeof(v_type_i));
- tmp_buf.resize(v_row_size);
- for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
- ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*v_row_stride, tmp_buf.size());
- data_ctx->write(tmp_buf.data(), tmp_buf.size());
+ // Write element size
+ const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+ write(&v_size_el, sizeof(v_size_el));
+
+ // Write GQA embedding size
+ write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+ // For each row, we get the element values of each cell
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ // Read each range of cells (v_size_el bytes per cell) into tmp_buf and write it out
+ for (const auto & range : cell_ranges) {
+ const size_t range_size = range.second - range.first;
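+ // in the transposed layout, row j of the value matrix starts at element j * kv_size, so the cells of a range are contiguous within each row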
+ const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+ tmp_buf.resize(range_size * v_size_el);
+ ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+ write(tmp_buf.data(), tmp_buf.size());
+ }
}
}
- GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
}
+ }
- for (uint32_t i = 0; i < kv_head; ++i) {
- const auto & cell = kv_self.cells[i];
-
- const llama_pos pos = cell.pos;
- const size_t seq_id_size = cell.seq_id.size();
-
- data_ctx->write(&pos, sizeof(pos));
- data_ctx->write(&seq_id_size, sizeof(seq_id_size));
+ void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
+ const struct llama_kv_cache & kv_self = ctx->kv_self;
+ std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+ uint32_t cell_count = 0;
- for (auto seq_id : cell.seq_id) {
- data_ctx->write(&seq_id, sizeof(seq_id));
+ // Count the number of cells with the specified seq_id
+ // Find all the ranges of cells with this seq id (or all, when -1)
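+ // kv_self.size doubles as a sentinel value meaning "no range currently open"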
+ uint32_t cell_range_begin = kv_self.size;
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
+ const auto & cell = kv_self.cells[i];
+ if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+ ++cell_count;
+ if (cell_range_begin == kv_self.size) {
+ cell_range_begin = i;
+ }
+ } else {
+ if (cell_range_begin != kv_self.size) {
+ cell_ranges.emplace_back(cell_range_begin, i);
+ cell_range_begin = kv_self.size;
+ }
}
}
- }
-}
+ if (cell_range_begin != kv_self.size) {
+ cell_ranges.emplace_back(cell_range_begin, kv_self.size);
+ }
-size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
- llama_data_buffer_context data_ctx(dst);
- llama_state_get_data_internal(ctx, &data_ctx);
+ // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+ uint32_t cell_count_check = 0;
+ for (const auto & range : cell_ranges) {
+ cell_count_check += range.second - range.first;
+ }
+ GGML_ASSERT(cell_count == cell_count_check);
- return data_ctx.get_size_written();
-}
+ write(&cell_count, sizeof(cell_count));
-// Sets the state reading from the specified source address
-size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
- llama_synchronize(ctx);
+ write_kv_cache_meta(kv_self, cell_ranges, seq_id);
+ write_kv_cache_data(ctx, cell_ranges);
+ }
+};
- const uint8_t * inp = src;
+struct llama_data_read {
+ virtual const uint8_t * read(size_t size) = 0;
+ virtual void read_to(void * dst, size_t size) = 0;
+ virtual size_t get_size_read() = 0;
+ virtual ~llama_data_read() = default;
- // set rng
- {
- size_t rng_size;
- memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
+ void read_string(std::string & str) {
+ uint32_t str_size;
+ read_to(&str_size, sizeof(str_size));
- GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
+ str.assign((const char *) read(str_size), str_size);
+ }
+
+ // validate model information
+ void read_model_info(const struct llama_context * ctx) {
+ std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+ std::string arch_str;
+ read_string(arch_str);
+ if (cur_arch_str != arch_str) {
+ throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+ }
+ // TODO: add more info which needs to be identical but which is not verified otherwise
+ }
- std::string rng_str((const char *)inp, rng_size); inp += rng_size;
+ void read_rng(std::mt19937 & rng) {
+ std::string rng_str;
+ read_string(rng_str);
std::istringstream rng_ss(rng_str);
- rng_ss >> ctx->sampling.rng;
+ rng_ss >> rng;
- GGML_ASSERT(!rng_ss.fail());
+ if (rng_ss.fail()) {
+ throw std::runtime_error("failed to load RNG state");
+ }
}
- // set output ids
- {
- size_t n_outputs;
+ void read_output_ids(struct llama_context * ctx) {
std::vector<int32_t> output_pos;
- memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+ uint32_t n_outputs;
+ read_to(&n_outputs, sizeof(n_outputs));
- GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs));
+ if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
+ throw std::runtime_error("could not reserve outputs");
+ }
if (n_outputs) {
output_pos.resize(n_outputs);
- memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
- inp += n_outputs * sizeof(int32_t);
+ read_to(output_pos.data(), n_outputs * sizeof(int32_t));
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
int32_t id = output_pos[i];
- GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+ if ((uint32_t) id >= ctx->cparams.n_batch) {
+ throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
+ }
ctx->output_ids[id] = i;
}
}
}
- // set logits
- {
- size_t logits_size;
-
- memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
+ void read_logits(struct llama_context * ctx) {
+ uint64_t logits_size;
+ read_to(&logits_size, sizeof(logits_size));
- GGML_ASSERT(ctx->logits_size >= logits_size);
+ if (ctx->logits_size < logits_size) {
+ throw std::runtime_error("logits buffer too small");
+ }
if (logits_size) {
- memcpy(ctx->logits, inp, logits_size * sizeof(float));
- inp += logits_size * sizeof(float);
+ read_to(ctx->logits, logits_size * sizeof(float));
}
}
- // set embeddings
- {
- size_t embeddings_size;
-
- memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
+ void read_embeddings(struct llama_context * ctx) {
+ uint64_t embeddings_size;
+ read_to(&embeddings_size, sizeof(embeddings_size));
- GGML_ASSERT(ctx->embd_size >= embeddings_size);
+ if (ctx->embd_size < embeddings_size) {
+ throw std::runtime_error("embeddings buffer too small");
+ }
if (embeddings_size) {
- memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
- inp += embeddings_size * sizeof(float);
+ read_to(ctx->embd, embeddings_size * sizeof(float));
}
}
- // set kv cache
- {
- const auto & kv_self = ctx->kv_self;
- const auto & hparams = ctx->model.hparams;
+ bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
+ struct llama_kv_cache & kv_self = ctx->kv_self;
- const uint32_t n_layer = hparams.n_layer;
+ if (dest_seq_id != -1) {
+ // single sequence
- size_t kv_buf_size;
- uint32_t kv_head;
- uint32_t kv_size;
- uint32_t kv_used;
- uint32_t v_trans;
+ llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+
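+ // allocate the destination cells by building a placeholder batch with the saved positions and letting llama_kv_cache_find_slot place it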
+ llama_batch batch = llama_batch_init(cell_count, 0, 1);
+ batch.n_tokens = cell_count;
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ llama_pos pos;
+ uint32_t n_seq_id;
- memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
- memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
- memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
- memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
- memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans);
+ read_to(&pos, sizeof(pos));
+ read_to(&n_seq_id, sizeof(n_seq_id));
- GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
+ if (n_seq_id != 0) {
+ LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+ return false;
+ }
- if (kv_self.size != kv_size) {
- // the KV cache needs to be big enough to load all the KV cells from the saved state
- GGML_ASSERT(kv_self.size >= kv_head);
+ batch.pos[i] = pos;
+ batch.n_seq_id[i] = 1;
+ batch.seq_id[i][0] = dest_seq_id;
+ }
+ if (!llama_kv_cache_find_slot(kv_self, batch)) {
+ llama_batch_free(batch);
+ LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+ return false;
+ }
+
+ // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+ // Assume that this is one contiguous block of cells
+ GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
+ GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
+ GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+ GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
+ GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
- LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
- __func__, kv_head, kv_size, kv_self.size);
+ // Cleanup
+ llama_batch_free(batch);
+ } else {
+ // whole KV cache restore
+
+ if (cell_count > kv_self.size) {
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+ return false;
+ }
+
+ llama_kv_cache_clear(kv_self);
+
+ for (uint32_t i = 0; i < cell_count; ++i) {
+ llama_kv_cell & cell = kv_self.cells[i];
+
+ llama_pos pos;
+ uint32_t n_seq_id;
+
+ read_to(&pos, sizeof(pos));
+ read_to(&n_seq_id, sizeof(n_seq_id));
+
+ cell.pos = pos;
+
+ for (uint32_t j = 0; j < n_seq_id; ++j) {
+ llama_seq_id seq_id;
+ read_to(&seq_id, sizeof(seq_id));
+
+ if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+ LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+ return false;
+ }
+
+ cell.seq_id.insert(seq_id);
+ }
+ }
+
+ kv_self.head = 0;
+ kv_self.used = cell_count;
+ }
+
+ return true;
+ }
+
+ bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
+ const struct llama_hparams & hparams = ctx->model.hparams;
+ struct llama_kv_cache & kv_self = ctx->kv_self;
+ uint32_t v_trans;
+ uint32_t n_layer;
+ read_to(&v_trans, sizeof(v_trans));
+ read_to(&n_layer, sizeof(n_layer));
+
+ if (n_layer != hparams.n_layer) {
+ LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+ return false;
+ }
+ if (cell_count > kv_self.size) {
+ LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
+ return false;
+ }
+ if (kv_self.v_trans != (bool) v_trans) {
+ LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+ return false;
}
- llama_kv_cache_clear(ctx);
+ // For each layer, read the keys for each cell; one row is one cell, read as one contiguous block
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
- if (kv_buf_size) {
- const size_t pre_kv_buf_size = inp - src;
+ // Read type of key
+ int32_t k_type_i_ref;
+ read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+ const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+ if (k_type_i != k_type_i_ref) {
+ LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+ return false;
+ }
+
+ // Read row size of key
+ uint64_t k_size_row_ref;
+ read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+ const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+ if (k_size_row != k_size_row_ref) {
+ LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+ return false;
+ }
- GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
+ if (cell_count) {
+ // Read and set the keys for the whole cell range
+ ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
+ }
+ }
- for (int il = 0; il < (int) n_layer; ++il) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+ if (!kv_self.v_trans) {
+ for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
- const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+ // Read type of value
+ int32_t v_type_i_ref;
+ read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ if (v_type_i != v_type_i_ref) {
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+ return false;
+ }
- ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
- inp += k_size;
+ // Read row size of value
+ uint64_t v_size_row_ref;
+ read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+ const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+ if (v_size_row != v_size_row_ref) {
+ LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+ return false;
+ }
- if (kv_self.recurrent || !kv_self.v_trans) {
- // v is contiguous for recurrent models
- // TODO: use other tensors for state models than k and v
- const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
+ if (cell_count) {
+ // Read and set the values for the whole cell range
+ ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
+ }
+ }
+ } else {
+ // For each layer, read the values for each cell (transposed)
+ for (uint32_t il = 0; il < n_layer; ++il) {
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
- ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
- inp += v_size;
- continue;
+ // Read type of value
+ int32_t v_type_i_ref;
+ read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+ const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+ if (v_type_i != v_type_i_ref) {
+ LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+ return false;
}
- // v is not contiguous, copy row by row
- const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
- const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
+ // Read element size of value
+ uint32_t v_size_el_ref;
+ read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+ const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+ if (v_size_el != v_size_el_ref) {
+ LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+ return false;
+ }
+
+ // Read GQA embedding size
+ uint32_t n_embd_v_gqa_ref;
+ read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+ if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+ LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+ return false;
+ }
- for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
- ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
- inp += v_row_size;
+ if (cell_count) {
+ // For each row in the transposed matrix, read the values for the whole cell range
+ for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+ const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
+ ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+ }
}
}
- GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
}
+ return true;
+ }
- ctx->kv_self.head = kv_head;
- ctx->kv_self.used = kv_used;
+ void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+ uint32_t cell_count;
+ read_to(&cell_count, sizeof(cell_count));
- for (uint32_t i = 0; i < kv_head; ++i) {
- llama_pos pos;
- size_t seq_id_size;
+ bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
- memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos);
- memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
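+ // on failure, wipe whatever was partially loaded so the cache is left in a consistent state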
+ if (!res) {
+ if (seq_id == -1) {
+ llama_kv_cache_clear(ctx);
+ } else {
+ llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
+ }
+ throw std::runtime_error("failed to restore kv cache");
+ }
+ }
+};
- ctx->kv_self.cells[i].pos = pos;
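+// a writer that only counts bytes; llama_state_get_size() uses it to compute the exact state size without serializing anything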
+struct llama_data_write_dummy : llama_data_write {
+ size_t size_written = 0;
- llama_seq_id seq_id;
+ llama_data_write_dummy() {}
- for (size_t j = 0; j < seq_id_size; ++j) {
- memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
- ctx->kv_self.cells[i].seq_id.insert(seq_id);
- }
+ // TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
+
+ void write(const void * /* src */, size_t size) override {
+ size_written += size;
+ }
+
+ size_t get_size_written() override {
+ return size_written;
+ }
+};
+
+struct llama_data_write_buffer : llama_data_write {
+ uint8_t * ptr;
+ size_t buf_size = 0;
+ size_t size_written = 0;
+
+ llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+ void write(const void * src, size_t size) override {
+ if (size > buf_size) {
+ throw std::runtime_error("unexpectedly reached end of buffer");
}
+ memcpy(ptr, src, size);
+ ptr += size;
+ size_written += size;
+ buf_size -= size;
}
- const size_t nread = inp - src;
- const size_t max_size = llama_state_get_size(ctx);
+ size_t get_size_written() override {
+ return size_written;
+ }
+};
- GGML_ASSERT(nread <= max_size);
+struct llama_data_read_buffer : llama_data_read {
+ const uint8_t * ptr;
+ size_t buf_size = 0;
+ size_t size_read = 0;
- return nread;
+ llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+ const uint8_t * read(size_t size) override {
+ const uint8_t * base_ptr = ptr;
+ if (size > buf_size) {
+ throw std::runtime_error("unexpectedly reached end of buffer");
+ }
+ ptr += size;
+ size_read += size;
+ buf_size -= size;
+ return base_ptr;
+ }
+
+ void read_to(void * dst, size_t size) override {
+ memcpy(dst, read(size), size);
+ }
+
+ size_t get_size_read() override {
+ return size_read;
+ }
+};
+
+struct llama_data_write_file : llama_data_write {
+ llama_file * file;
+ size_t size_written = 0;
+
+ llama_data_write_file(llama_file * f) : file(f) {}
+
+ void write(const void * src, size_t size) override {
+ file->write_raw(src, size);
+ size_written += size;
+ }
+
+ size_t get_size_written() override {
+ return size_written;
+ }
+};
+
+struct llama_data_read_file : llama_data_read {
+ llama_file * file;
+ size_t size_read = 0;
+ std::vector<uint8_t> temp_buffer;
+
+ llama_data_read_file(llama_file * f) : file(f) {}
+
+ void read_to(void * dst, size_t size) override {
+ file->read_raw(dst, size);
+ size_read += size;
+ }
+
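+ // the returned pointer is only valid until the next call to read()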
+ const uint8_t * read(size_t size) override {
+ temp_buffer.resize(size);
+ read_to(temp_buffer.data(), size);
+ return temp_buffer.data();
+ }
+
+ size_t get_size_read() override {
+ return size_read;
+ }
+};
+
+/** copy state data into either a buffer or a file depending on the passed-in context
+ *
+ * file context:
+ * llama_file file("/path", "wb");
+ * llama_data_write_file data_ctx(&file);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+ * buffer context:
+ * std::vector<uint8_t> buf(max_size, 0);
+ * llama_data_write_buffer data_ctx(buf.data(), max_size);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+*/
+static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
+ llama_synchronize(ctx);
+
+ data_ctx.write_model_info(ctx);
+
+ data_ctx.write_rng(ctx->sampling.rng);
+
+ // copy outputs
+ data_ctx.write_output_ids(ctx);
+ data_ctx.write_logits(ctx);
+ data_ctx.write_embeddings(ctx);
+
+ data_ctx.write_kv_cache(ctx);
+
+ return data_ctx.get_size_written();
+}
+
+size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
+ llama_data_write_buffer data_ctx(dst, size);
+ try {
+ return llama_state_get_data_internal(ctx, data_ctx);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+ return 0;
+ }
+}
+
+// Returns the *actual* size of the state.
+// Intended to be used when saving the state to a buffer.
+size_t llama_state_get_size(struct llama_context * ctx) {
+ llama_data_write_dummy data_ctx;
+ try {
+ return llama_state_get_data_internal(ctx, data_ctx);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+ return 0;
+ }
+}
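+
+// usage sketch (illustrative only, not part of this change): query the exact size
+// first, then serialize into a caller-provided buffer of that size; the two calls
+// should agree as long as the context state does not change in between
+//
+//   std::vector<uint8_t> buf(llama_state_get_size(ctx));
+//   const size_t n_written = llama_state_get_data(ctx, buf.data(), buf.size());
+//   GGML_ASSERT(n_written == buf.size());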
+
+static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
+ llama_synchronize(ctx);
+
+ data_ctx.read_model_info(ctx);
+
+ // set rng
+ data_ctx.read_rng(ctx->sampling.rng);
+
+ // set outputs
+ data_ctx.read_output_ids(ctx);
+ data_ctx.read_logits(ctx);
+ data_ctx.read_embeddings(ctx);
+
+ data_ctx.read_kv_cache(ctx);
+
+ return data_ctx.get_size_read();
+}
+
+// Sets the state, reading from the specified source address
+size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
+ llama_data_read_buffer data_ctx(src, size);
+ try {
+ return llama_state_set_data_internal(ctx, data_ctx);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+ return 0;
+ }
}
static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
const uint32_t version = file.read_u32();
if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
- LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
- return false;
- }
-
- llama_hparams session_hparams;
- file.read_raw(&session_hparams, sizeof(llama_hparams));
-
- if (session_hparams != ctx->model.hparams) {
- LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__);
+ LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
return false;
}
}
const uint32_t n_token_count = file.read_u32();
if (n_token_count > n_token_capacity) {
- LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+ LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
return false;
}
// restore the context state
{
const size_t n_state_size_cur = file.size - file.tell();
- const size_t n_state_size_max = llama_state_get_size(ctx);
- if (n_state_size_cur > n_state_size_max) {
- LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
+ llama_data_read_file data_ctx(&file);
+ const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
+
+ if (n_read != n_state_size_cur) {
+ LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
return false;
}
-
- std::vector<uint8_t> state_data(n_state_size_max);
- file.read_raw(state_data.data(), n_state_size_cur);
-
- llama_state_set_data(ctx, state_data.data());
}
-
return true;
}
try {
return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
} catch (const std::exception & err) {
- LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
+ LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
return false;
}
}
file.write_u32(LLAMA_SESSION_MAGIC);
file.write_u32(LLAMA_SESSION_VERSION);
- file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
-
// save the prompt
file.write_u32((uint32_t) n_token_count);
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
- llama_data_file_context data_ctx(&file);
- llama_state_get_data_internal(ctx, &data_ctx);
+ llama_data_write_file data_ctx(&file);
+ llama_state_get_data_internal(ctx, data_ctx);
return true;
}
try {
return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
} catch (const std::exception & err) {
- LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
+ LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
return false;
}
}
-size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
- // save the size of size_t as a uint32_t for safety check
- const size_t size_t_size_size = sizeof(uint32_t);
-
- // other values
- const size_t s_cell_count_size = sizeof(uint32_t);
- const size_t s_layer_count_size = sizeof(uint32_t);
- const size_t n_embd_v_gqa_size = sizeof(uint32_t);
-
- size_t s_cell_count = 0;
- size_t s_cell_data_size = 0;
- const auto & kv_self = ctx->kv_self;
- const auto & hparams = ctx->model.hparams;
-
- const uint32_t n_layer = hparams.n_layer;
-
- for (uint32_t i = 0; i < kv_self.size; ++i) {
- const auto & cell = kv_self.cells[i];
- if (cell.seq_id.count(seq_id) > 0) {
- ++s_cell_count;
- s_cell_data_size += sizeof(llama_pos);
- }
- }
-
- for (int il = 0; il < (int)n_layer; ++il) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
- // types of keys and values
- s_cell_data_size += sizeof(int32_t) * 2;
- // k_size_row and v_size_el values of layer
- s_cell_data_size += sizeof(size_t) * 2;
-
- // keys
- const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
- s_cell_data_size += k_size_row * s_cell_count;
-
- // values (transposed)
- const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
- s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa;
- }
-
- const size_t s_total = (
- size_t_size_size +
- s_cell_count_size +
- s_layer_count_size +
- n_embd_v_gqa_size +
- s_cell_data_size
- );
-
- return s_total;
-}
-
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
llama_synchronize(ctx);
- const auto & kv_self = ctx->kv_self;
- GGML_ASSERT(!kv_self.recurrent); // not implemented
-
- // Save the size of size_t as a uint32_t for safety check
- const uint32_t size_t_size = sizeof(size_t);
- data_ctx.write(&size_t_size, sizeof(size_t_size));
-
- std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
- uint32_t cell_count = 0;
-
- // Count the number of cells with the specified seq_id
- // Find all the ranges of cells with this seq id
- {
- uint32_t cell_range_begin = kv_self.size;
- for (uint32_t i = 0; i < kv_self.size; ++i) {
- const auto & cell = kv_self.cells[i];
- if (cell.has_seq_id(seq_id)) {
- ++cell_count;
- if (cell_range_begin == kv_self.size) {
- cell_range_begin = i;
- }
- }
- else {
- if (cell_range_begin != kv_self.size) {
- cell_ranges.emplace_back(cell_range_begin, i);
- cell_range_begin = kv_self.size;
- }
- }
- }
- if (cell_range_begin != kv_self.size) {
- cell_ranges.emplace_back(cell_range_begin, kv_self.size);
- }
-
- // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
- uint32_t cell_count_check = 0;
- for (const auto & range : cell_ranges) {
- cell_count_check += range.second - range.first;
- }
- GGML_ASSERT(cell_count == cell_count_check);
- }
-
- // Write the cell count
- data_ctx.write(&cell_count, sizeof(cell_count));
-
- const auto & hparams = ctx->model.hparams;
- const uint32_t n_layer = hparams.n_layer;
-
- // Write the layer count
- data_ctx.write(&n_layer, sizeof(n_layer));
-
- // Write n_embd_v_gqa (reference value)
- {
- const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
- data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
- }
-
- // Iterate the ranges and write all the pos (this is the token position in the prompt)
- for (const auto & range : cell_ranges) {
- for (uint32_t i = range.first; i < range.second; ++i) {
- const auto & cell = kv_self.cells[i];
- data_ctx.write(&cell.pos, sizeof(cell.pos));
- }
- }
-
- // Iterate and write all the keys first, each row is a cell
- // Get whole range at a time
- std::vector<uint8_t> tmp_buf;
- for (int il = 0; il < (int)n_layer; ++il) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
- // Write key type
- const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
- data_ctx.write(&k_type_i, sizeof(k_type_i));
-
- // Write row size of key
- const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
- data_ctx.write(&k_size_row, sizeof(k_size_row));
-
- // Read each range of cells of k_size length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
- const size_t range_size = range.second - range.first;
- tmp_buf.resize(range_size * k_size_row);
- ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
- data_ctx.write(tmp_buf.data(), tmp_buf.size());
- }
- }
-
- // TODO: simplify, reduce copy-paste
- if (!kv_self.v_trans) {
- for (int il = 0; il < (int)n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
- // Write value type
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
- data_ctx.write(&v_type_i, sizeof(v_type_i));
-
- // Write row size of value
- const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
- data_ctx.write(&v_size_row, sizeof(v_size_row));
-
- // Read each range of cells of v_size length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
- const size_t range_size = range.second - range.first;
- tmp_buf.resize(range_size * v_size_row);
- ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
- data_ctx.write(tmp_buf.data(), tmp_buf.size());
- }
- }
- } else {
- // For the values, they are transposed, so we also need the element size and get the element ranges from each row
- const uint32_t kv_size = kv_self.size;
- for (int il = 0; il < (int)n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
- // Write value type
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
- data_ctx.write(&v_type_i, sizeof(v_type_i));
-
- // Write element size
- const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
- data_ctx.write(&v_size_el, sizeof(v_size_el));
-
- // For each row, we get the element values of each cell
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- // Read each range of cells of v_size_el length each into tmp_buf and write out
- for (const auto & range : cell_ranges) {
- const size_t range_size = range.second - range.first;
- const size_t src_offset = (range.first + j * kv_size) * v_size_el;
- tmp_buf.resize(range_size * v_size_el);
- ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
- data_ctx.write(tmp_buf.data(), tmp_buf.size());
- }
- }
- }
- }
+ data_ctx.write_kv_cache(ctx, seq_id);
return data_ctx.get_size_written();
}
-size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) {
- llama_data_buffer_context data_ctx(dst);
+size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
+ llama_data_write_dummy data_ctx;
return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
}
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
- llama_synchronize(ctx);
-
- auto & kv_self = ctx->kv_self;
- GGML_ASSERT(!kv_self.recurrent); // not implemented
-
- // Wipe the slot
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-
- const uint8_t * inp = src;
-
- // Read size of size_t
- uint32_t size_t_size;
- memcpy(&size_t_size, inp, sizeof(size_t_size));
- inp += sizeof(size_t_size);
- if (size_t_size != sizeof(size_t)) {
- LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__);
+size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+ llama_data_write_buffer data_ctx(dst, size);
+ try {
+ return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
return 0;
}
+}
- // Read the cell count
- uint32_t cell_count;
- memcpy(&cell_count, inp, sizeof(cell_count));
- inp += sizeof(cell_count);
-
- // Read the layer count
- uint32_t n_layer_ref;
- memcpy(&n_layer_ref, inp, sizeof(n_layer_ref));
- inp += sizeof(n_layer_ref);
-
- // Read n_embd_v_gqa
- uint32_t n_embd_v_gqa_ref;
- memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
- inp += sizeof(n_embd_v_gqa_ref);
+static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
+ llama_synchronize(ctx);
- // Sanity check model compatibility
- const auto & hparams = ctx->model.hparams;
- const uint32_t n_layer = hparams.n_layer;
+ data_ctx.read_kv_cache(ctx, dest_seq_id);
- if (n_layer != n_layer_ref) {
- LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
- return 0;
- }
+ return data_ctx.get_size_read();
+}
- if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
- LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
+size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
+ llama_data_read_buffer data_ctx(src, size);
+ try {
+ return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
return 0;
}
-
- // Allocate the new cells for the slot
- if (cell_count) {
- llama_batch batch = llama_batch_init(cell_count, 0, 1);
- batch.n_tokens = cell_count;
- for (uint32_t i = 0; i < cell_count; ++i) {
- llama_pos pos;
- memcpy(&pos, inp, sizeof(pos));
- inp += sizeof(pos);
-
- batch.pos[i] = pos;
- batch.n_seq_id[i] = 1;
- batch.seq_id[i][0] = dest_seq_id;
- }
- if (!llama_kv_cache_find_slot(kv_self, batch)) {
- llama_batch_free(batch);
- LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
- return 0;
- }
-
- // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
- // Assume that this is one contiguous block of cells
- GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
- GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
- GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
- GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
- GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-
- // Cleanup
- llama_batch_free(batch);
- }
-
- const uint32_t kv_size = kv_self.size;
- const uint32_t kv_head = kv_self.head;
-
- // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
- for (int il = 0; il < (int)n_layer; ++il) {
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
- // Read type of key
- int32_t k_type_i_ref;
- memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
- inp += sizeof(k_type_i_ref);
- const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
- if (k_type_i != k_type_i_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
- return 0;
- }
-
- // Read row size of key
- size_t k_size_row_ref;
- memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
- inp += sizeof(k_size_row_ref);
- const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
- if (k_size_row != k_size_row_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il);
- return 0;
- }
-
- if (cell_count) {
- // Read and set the keys for the whole cell range
- ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row);
- inp += cell_count * k_size_row;
- }
- }
-
- // TODO: simplify, reduce copy-paste
- if (!kv_self.v_trans) {
- for (int il = 0; il < (int)n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
- // Read type of value
- int32_t v_type_i_ref;
- memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
- inp += sizeof(v_type_i_ref);
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
- if (v_type_i != v_type_i_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
- return 0;
- }
-
- // Read row size of value
- size_t v_size_row_ref;
- memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
- inp += sizeof(v_size_row_ref);
- const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
- if (v_size_row != v_size_row_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
- return 0;
- }
-
- if (cell_count) {
- // Read and set the values for the whole cell range
- ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
- inp += cell_count * v_size_row;
- }
- }
- } else {
- // For each layer, read the values for each cell (transposed)
- for (int il = 0; il < (int)n_layer; ++il) {
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
- // Read type of value
- int32_t v_type_i_ref;
- memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
- inp += sizeof(v_type_i_ref);
- const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
- if (v_type_i != v_type_i_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
- return 0;
- }
-
- // Read element size of value
- size_t v_size_el_ref;
- memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
- inp += sizeof(v_size_el_ref);
- const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
- if (v_size_el != v_size_el_ref) {
- llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
- LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
- return 0;
- }
-
- if (cell_count) {
- // For each row in the transposed matrix, read the values for the whole cell range
- for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
- const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
- ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
- inp += cell_count * v_size_el;
- }
- }
- }
- }
-
- const size_t nread = inp - src;
-
- return nread;
}
static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
file.write_u32(LLAMA_STATE_SEQ_VERSION);
// save the prompt
- file.write_u32((uint32_t)n_token_count);
+ file.write_u32((uint32_t) n_token_count);
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
- llama_data_file_context data_ctx(&file);
+ llama_data_write_file data_ctx(&file);
llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
const size_t res = file.tell();
// restore the context state
{
const size_t state_size = file.size - file.tell();
- std::vector<uint8_t> state_data(state_size);
- file.read_raw(state_data.data(), state_size);
- const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
+ llama_data_read_file data_ctx(&file);
+ const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
return 0;
try {
return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
} catch (const std::exception & err) {
- LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
+ LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
return 0;
}
}
try {
return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
} catch (const std::exception & err) {
- LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+ LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
return 0;
}
}