}
};
+// temporary allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
+struct llama_batch_allocr {
+ std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
+ std::vector<llama_pos> pos;
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id;
+ std::vector<int8_t> logits;
+ struct llama_batch batch;
+ // optionally fulfill the batch returned by llama_batch_get_one
+ llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
+ batch = in_batch;
+ GGML_ASSERT(batch.n_tokens > 0);
+ if (!batch.pos) {
+ // determine the last position in KV cache
+ llama_pos last_pos = -1;
+ for (const auto & cell : ctx.kv_self.cells) {
+ if (cell.has_seq_id(batch_default_seq_id)) {
+ last_pos = std::max(last_pos, cell.pos);
+ }
+ }
+ last_pos++; // next position
+ pos.resize(batch.n_tokens);
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
+ pos[i] = i+last_pos;
+ }
+ batch.pos = pos.data();
+ }
+ if (!batch.n_seq_id) {
+ n_seq_id.resize(batch.n_tokens);
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
+ n_seq_id[i] = seq_id_0.size();
+ }
+ batch.n_seq_id = n_seq_id.data();
+ }
+ if (!batch.seq_id) {
+ seq_id.resize(batch.n_tokens + 1);
+ seq_id[batch.n_tokens] = NULL;
+ for (int32_t i = 0; i < batch.n_tokens; i++) {
+ seq_id[i] = seq_id_0.data();
+ }
+ batch.seq_id = seq_id.data();
+ }
+ if (!batch.logits) {
+ logits.resize(batch.n_tokens);
+ logits[logits.size() - 1] = true;
+ batch.logits = logits.data();
+ }
+ }
+};
+
template<>
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
uint32_t tmp;
//
static int llama_decode_internal(
llama_context & lctx,
- llama_batch batch) {
+ llama_batch inp_batch) {
lctx.is_encoding = false;
- const uint32_t n_tokens_all = batch.n_tokens;
- if (n_tokens_all == 0) {
+ if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
+ // temporary allocate memory for the input batch if needed
+ llama_batch_allocr batch_allocr(lctx, inp_batch);
+ const llama_batch & batch = batch_allocr.batch;
+ const uint32_t n_tokens_all = batch.n_tokens;
+
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
//
static int llama_encode_internal(
llama_context & lctx,
- llama_batch batch) {
+ llama_batch inp_batch) {
lctx.is_encoding = true;
- const uint32_t n_tokens = batch.n_tokens;
-
- if (n_tokens == 0) {
+ if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
+ // temporary allocate memory for the input batch if needed
+ llama_batch_allocr batch_allocr(lctx, inp_batch);
+ const llama_batch & batch = batch_allocr.batch;
+ const uint32_t n_tokens = batch.n_tokens;
+
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
if (batch.logits) free(batch.logits);
}
-// temporary allocate memory for the input batch if needed
-static const llama_seq_id batch_default_seq_id = 0;
-struct llama_batch_allocr {
- std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
- std::vector<llama_pos> pos;
- std::vector<int32_t> n_seq_id;
- std::vector<llama_seq_id *> seq_id;
- std::vector<int8_t> logits;
- struct llama_batch batch;
- // optionally fulfill the batch returned by llama_batch_get_one
- llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
- batch = in_batch;
- if (!batch.pos) {
- // determine the last position in KV cache
- llama_pos last_pos = -1;
- for (const auto & cell : ctx->kv_self.cells) {
- if (cell.has_seq_id(batch_default_seq_id)) {
- last_pos = std::max(last_pos, cell.pos);
- }
- }
- last_pos++; // next position
- pos.resize(batch.n_tokens);
- for (int32_t i = 0; i < batch.n_tokens; i++) {
- pos[i] = i+last_pos;
- }
- batch.pos = pos.data();
- }
- if (!batch.n_seq_id) {
- n_seq_id.resize(batch.n_tokens);
- for (int32_t i = 0; i < batch.n_tokens; i++) {
- n_seq_id[i] = seq_id_0.size();
- }
- batch.n_seq_id = n_seq_id.data();
- }
- if (!batch.seq_id) {
- seq_id.resize(batch.n_tokens + 1);
- seq_id[batch.n_tokens] = NULL;
- for (int32_t i = 0; i < batch.n_tokens; i++) {
- seq_id[i] = seq_id_0.data();
- }
- batch.seq_id = seq_id.data();
- }
- if (!batch.logits) {
- logits.resize(batch.n_tokens);
- logits[logits.size() - 1] = true;
- batch.logits = logits.data();
- }
- }
-};
-
int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch) {
- llama_batch_allocr batch_allocr(ctx, batch);
- const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
+ const int ret = llama_encode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
- llama_batch_allocr batch_allocr(ctx, batch);
- const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
+ const int ret = llama_decode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}