llama_batch get_view(int offset, int n_tokens) {
llama_pos * pos_ptr;
pos_view.clear();
- pos_view.resize(n_tokens * n_pos_per_embd);
+ pos_view.reserve(n_tokens * n_pos_per_embd);
if (n_pos_per_embd > 1) {
// mrope
// for example, with layout of src: 1234...1234...1234...1234...
// offset 2 will give us dst: 34...34...34...34...
for (int i = 0; i < n_pos_per_embd; i++) {
- auto src = pos.begin() + i * batch.n_tokens + offset;
- pos_view.insert(pos_view.end(), src, src + n_tokens);
+ // assume n_tokens is less than or equal to batch.n_tokens
+ // batch.n_tokens is number of **total** tokens
+ // n_tokens is number of viewed token
+ size_t src_idx = i * batch.n_tokens + offset;
+ pos_view.insert(pos_view.end(),
+ pos.data() + src_idx,
+ pos.data() + src_idx + n_tokens);
}
pos_ptr = pos_view.data();
} else {