}
// Split the incoming batch into micro-batches and prepare the recurrent
// memory for them.
//
// @param balloc   batch allocator holding the tokens to split; split_reset()
//                 is called first so splitting starts from a clean state
// @param n_ubatch maximum number of tokens per micro-batch
// @param embd_all when true, all tokens produce output, so the batch is
//                 split by sequence (split_seq); otherwise split_equal is used
// @return a context owning the prepared ubatches on success, or a context
//         with LLAMA_MEMORY_STATUS_FAILED_PREPARE on failure
//
// NOTE: the do { ... } while (false) wrapper exists purely as an early-exit
// scope — any failure breaks out to the single FAILED_PREPARE return below.
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;

        while (true) {
            llama_ubatch ubatch;

            if (embd_all) {
                // if all tokens are output, split by sequence
                ubatch = balloc.split_seq(n_ubatch);
            } else {
                ubatch = balloc.split_equal(n_ubatch);
            }

            // the allocator signals exhaustion with an empty ubatch
            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        // reserve memory slots for the ubatches; bail out to the failure
        // return if the recurrent state cannot accommodate them
        if (!prepare(ubatches)) {
            break;
        }

        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
    } while (false);

    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_memory_recurrent::init_full() {