clear();
split_reset();
+ const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
+
auto udata = std::make_shared<llama_ubatch::data_t>();
udata->token .resize(n_tokens);
udata->embd .clear();
- udata->pos .resize(n_tokens);
+ udata->pos .resize(n_pos_all);
udata->n_seq_id .resize(n_tokens);
udata->seq_id .resize(n_tokens);
udata->seq_id_unq.resize(0);
io.write(&pos, sizeof(pos));
io.write(&n_seq_id, sizeof(n_seq_id));
- // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
- // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
+ if (hparams.n_pos_per_embd() > 1) {
+ const llama_kv_cell_ext ext = cells.ext_get(i);
+ io.write(&ext, sizeof(ext));
+ }
for (const auto & seq_id : seq_ids) {
io.write(&seq_id, sizeof(seq_id));
return false;
}
+ if (hparams.n_pos_per_embd() > 1) {
+ llama_kv_cell_ext ext;
+ io.read_to(&ext, sizeof(ext));
+
+ ubatch.pos[i + ubatch.n_tokens] = ext.y;
+ ubatch.pos[i + ubatch.n_tokens*2] = ext.x;
+ }
+
// read the sequence id, but directly discard it - we will use dest_seq_id instead
{
llama_seq_id seq_id;