return n_outputs;
}
+uint32_t llama_batch_allocr::get_n_used() const {
+ return n_used; // running count of tokens flagged in used[]; split_reset() sets it back to 0
+}
+
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids; // non-const reference to the member vector, so callers can modify it in place
}
void llama_batch_allocr::split_reset() {
out_ids.clear();
+ n_used = 0;
+
used.clear();
used.resize(get_n_tokens(), false);
idxs.push_back(cur_idx);
used[cur_idx] = true;
+ ++n_used;
++cur_idx;
idxs_per_seq[s].push_back(idx);
used[idx] = true;
+ ++n_used;
++cur_idx[s];
}
idxs.push_back(cur_idx);
used[cur_idx] = true;
+ ++n_used;
if (idxs.size() >= n_ubatch) {
break;
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
+ uint32_t get_n_used() const;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
// batch indices of the output
std::vector<int32_t> out_ids;
+ uint32_t n_used; // number of tokens currently flagged in used[]; no in-class initializer here, split_reset() assigns 0
+
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;
ubatches.push_back(std::move(ubatch)); // NOLINT
}
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
+ break;
+ }
+
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
ubatches.push_back(std::move(ubatch)); // NOLINT
}
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
+ break;
+ }
+
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
ubatches.push_back(std::move(ubatch)); // NOLINT
}
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
+ break;
+ }
+
auto sinfos = prepare(ubatches);
if (sinfos.empty()) {
break;
ubatches.push_back(std::move(ubatch)); // NOLINT
}
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
+ break;
+ }
+
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?
ubatch = balloc.split_equal(n_ubatch);
}
- if (ubatch.n_tokens == 0) {
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
break;
}