batch.seq_id = seq_id.data();
}
if (!batch.logits) {
- logits.resize(batch.n_tokens);
- logits[logits.size() - 1] = true;
- batch.logits = logits.data();
+ // by default return the output only for the last token
+ output.resize(batch.n_tokens);
+ output[output.size() - 1] = true;
+ batch.logits = output.data();
}
}
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
- std::vector<int8_t> logits;
+ std::vector<int8_t> output;
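+ // per-token output flags; assigned to batch.logits when the caller does not provide it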
// optionally fulfill the batch returned by llama_batch_get_one
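+ // (missing pos / n_seq_id / seq_id / logits data is filled in and backed by the vectors above)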
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
t_compute_start_us = ggml_time_us();
}
+ // TODO: this clear of the buffer can easily be forgotten - need something better
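+ // embd_seq holds the pooled embedding of each sequence; drop any results from the previous call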
embd_seq.clear();
n_queued_tokens += n_tokens;
}
}
+ // this indicates we are doing pooled embedding
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
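+ // number of tokens for which logits/embeddings will be returned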
+ int64_t n_outputs_all = 0;
+
+ // count outputs
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
+ n_outputs_all += batch.logits[i] != 0;
+ }
+
+ if (embd_pooled) {
+ // require that all tokens are output
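+ // (pooled embeddings, e.g. mean pooling, aggregate over every token in the sequence)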
+ if (n_outputs_all != n_tokens_all) {
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+ __func__, n_outputs_all, n_tokens_all);
+ return -1;
+ }
+ }
+
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
}
n_queued_tokens += n_tokens_all;
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+ // TODO: this clear of the buffer can easily be forgotten - need something better
embd_seq.clear();
- int64_t n_outputs_all = 0;
-
- // count outputs
- if (batch.logits && !embd_pooled) {
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
- n_outputs_all += batch.logits[i] != 0;
- }
- } else if (embd_pooled) {
- n_outputs_all = n_tokens_all;
- } else {
- // keep last output only
- n_outputs_all = 1;
- }
-
bool did_optimize = false;
// handle any pending defrags/shifts
do {
const auto & ubatch = mstate->get_ubatch();
- // count the outputs in this u_batch
+ // count the outputs in this ubatch
{
int32_t n_outputs_new = 0;
n_queued_tokens += n_tokens_all;
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+ // this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();