ubatch.seq_id = batch->seq_id + seq.offset;
}
}
- if (logits_all) {
- for (size_t i = 0; i < length; ++i) {
- ubatch.output[ubatch.n_tokens + i] = 1;
- out_ids.push_back(ids[seq.offset + i]);
- }
- } else if (batch->logits) {
+ if (batch->logits) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
return ubatch;
}
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
- this->logits_all = logits_all;
n_tokens = batch.n_tokens;
ids.resize(n_tokens);
size_t n_embd;
- bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
// sorted indices into the batch
std::vector<int64_t> ids;
// batch indices of the output
llama_ubatch split_seq(size_t n_ubatch);
llama_sbatch() = default;
- llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+ llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
};
// temporary allocate memory for the input batch if needed
const int64_t n_embd = hparams.n_embd;
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+ llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
llama_memory_state_ptr mstate;
while (true) {
- mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+ mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
if (!mstate) {
return -2;
}
int64_t n_outputs_all = n_tokens_all;
- auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+ auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
return result;
}
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
std::vector<llama_ubatch> ubatches;
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
- bool embd_pooled,
- bool logits_all) override;
+ bool embd_pooled) override;
llama_memory_state_ptr init_full() override;
return kv_swa->seq_pos_max(seq_id);
}
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
// first try simple split
do {
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
std::vector<llama_ubatch> ubatches;
// if it fails, try equal split
do {
- auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
std::vector<llama_ubatch> ubatches;
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
- bool embd_pooled,
- bool logits_all) override;
+ bool embd_pooled) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
- bool embd_pooled,
- bool logits_all) {
+ bool embd_pooled) {
GGML_UNUSED(embd_pooled);
do {
- auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+ auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) {
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
- bool embd_pooled,
- bool logits_all) override;
+ bool embd_pooled) override;
llama_memory_state_ptr init_full() override;
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
- bool embd_pooled,
- bool logits_all) = 0;
+ bool embd_pooled) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;