#include "llama-context.h"
#include "llama-impl.h"
+#include "llama-batch.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "llama-mmap.h"
llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
- model(model) {
+ model(model),
+ batch_allocr(std::make_unique<llama_batch_allocr>()) {
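+ // note: batch_allocr is now a context member, reused across encode() and decode() calls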
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
t_start_us = model.t_start_us;
}
float * llama_context::get_logits_ith(int32_t i) {
- int32_t j = -1;
+ int64_t j = -1;
try {
if (logits == nullptr) {
}
if (j >= n_outputs) {
// This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
return logits + j*model.vocab.n_tokens();
}
float * llama_context::get_embeddings_ith(int32_t i) {
- int32_t j = -1;
+ int64_t j = -1;
try {
if (embd == nullptr) {
}
if (j >= n_outputs) {
// This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+ throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
return embd + j*model.hparams.n_embd;
return res;
}
-int llama_context::encode(llama_batch & inp_batch) {
- if (inp_batch.n_tokens == 0) {
+int llama_context::encode(const llama_batch & batch_inp) {
+ if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporarily allocate memory for the input batch if needed
// note: during encode, we always pass the full sequence starting from pos = 0
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
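+ // token and seq_id validation now happens inside llama_batch_allocr::init()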
+ if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return -1;
+ }
- const llama_batch & batch = batch_allocr.batch;
- const int32_t n_tokens = batch.n_tokens;
+ const llama_batch & batch = batch_allocr->get_batch();
- const auto & hparams = model.hparams;
+ const uint32_t n_tokens = batch.n_tokens;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
- // TODO: move the validation to the llama_batch_allocr
- if (batch.token) {
- for (int32_t i = 0; i < n_tokens; ++i) {
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
- return -1;
- }
-
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
- LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
- throw -1;
- }
- }
- }
-
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
- GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
+ GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
n_queued_tokens += n_tokens;
+ const auto & hparams = model.hparams;
+
const int64_t n_embd = hparams.n_embd;
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
return -2;
};
- for (int32_t i = 0; i < n_tokens; ++i) {
+ for (uint32_t i = 0; i < n_tokens; ++i) {
output_ids[i] = i;
}
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
- for (int32_t i = 0; i < n_tokens; i++) {
+ // TODO: fix indexing [UBATCH_IDX]
+ for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out;
+ // TODO: fix indexing [UBATCH_IDX]
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
// remember the sequence ids used during the encoding - needed for cross attention later
- // TODO: the seuqence indexing here is likely not correct in the general case
- // probably works only for split_simple
cross.seq_ids_enc.resize(n_tokens);
- for (int32_t i = 0; i < n_tokens; i++) {
+ for (uint32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear();
- for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
- llama_seq_id seq_id = ubatch.seq_id[i][s];
+ for (int s = 0; s < batch.n_seq_id[i]; s++) {
+ llama_seq_id seq_id = batch.seq_id[i][s];
cross.seq_ids_enc[i].insert(seq_id);
}
}
return 0;
}
-int llama_context::decode(llama_batch & inp_batch) {
+int llama_context::decode(const llama_batch & batch_inp) {
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
- return encode(inp_batch);
+ return encode(batch_inp);
}
- if (inp_batch.n_tokens == 0) {
+ if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
- if (!inp_batch.pos) {
- if (inp_batch.seq_id) {
- LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
- return -1;
- }
- }
-
// temporarily allocate memory for the input batch if needed
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
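+ // as in encode(), token and seq_id validation is handled by the allocator during init()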
+ if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return -1;
+ }
- const llama_batch & batch = batch_allocr.batch;
+ const llama_batch & batch = batch_allocr->get_batch();
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
+ const int64_t n_embd = hparams.n_embd;
- const int64_t n_tokens_all = batch.n_tokens;
- const int64_t n_embd = hparams.n_embd;
+ const uint32_t n_tokens_all = batch.n_tokens;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
- // TODO: move the validation to the llama_batch_allocr
- if (batch.token) {
- for (int64_t i = 0; i < n_tokens_all; ++i) {
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
- LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
- return -1;
- }
-
- if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
- LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
- return -1;
- }
- }
- }
-
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
- int64_t n_outputs_all = 0;
-
- // count outputs
- for (uint32_t i = 0; i < n_tokens_all; ++i) {
- n_outputs_all += batch.logits[i] != 0;
- }
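+ // the output count is provided by the batch allocator, replacing the explicit scan over batch.logits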
+ const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
if (embd_pooled) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
- LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+ LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
__func__, n_outputs_all, n_tokens_all);
return -1;
}
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
return -2;
};
pos_min[s] = std::numeric_limits<llama_pos>::max();
}
+ // TODO: fix sequence indexing
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0];
n_outputs = n_outputs_all;
// set output mappings
- {
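+ // skip the remapping entirely when the batch produced no outputs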
+ if (n_outputs > 0) {
bool sorted_output = true;
auto & out_ids = mstate->out_ids();
- GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
+ GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
- for (int64_t i = 0; i < n_outputs_all; ++i) {
+ for (int64_t i = 0; i < n_outputs; ++i) {
int64_t out_id = out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
// note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
- const uint32_t n_embd = model.hparams.n_embd;
+ const uint64_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
- for (int32_t i = 0; i < n_outputs - 1; ++i) {
- int32_t j_min = i;
- for (int32_t j = i + 1; j < n_outputs; ++j) {
+ for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+ uint32_t j_min = i;
+ for (uint32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
- if (j_min == i) { continue; }
+ if (j_min == i) {
+ continue;
+ }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
}
}
}
+
std::fill(output_ids.begin(), output_ids.end(), -1);
- for (int32_t i = 0; i < n_outputs; ++i) {
+
+ for (uint32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
}
// output
//
-int32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
- this->n_outputs = 0;
- this->n_outputs_max = n_outputs_max;
+ this->n_outputs = 0;
return n_outputs_max;
}
std::vector<int32_t> w_output_pos;
- GGML_ASSERT(n_outputs <= n_outputs_max);
-
w_output_pos.resize(n_outputs);
// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
- int32_t pos = output_ids[i];
+ int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
embd_seq.clear();
- int64_t n_outputs_all = n_tokens_all;
+ uint32_t n_outputs_all = n_tokens_all;
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
- LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
GGML_ABORT("TODO: handle this error");
};