Splits producing more than one ubatch per batch for recurrent models
were broken with #14512.
This fixes it by moving the completeness check after the ubatch split loop.
ubatch = balloc.split_equal(n_ubatch, false);
}
- if (balloc.get_n_used() < balloc.get_n_tokens()) {
- // failed to find a suitable split
+ if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
+ // failed to find a suitable split
+ break;
+ }
+
if (!prepare(ubatches)) {
break;
}