return status;
}
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-// - lctx: llama context
-// - batch: batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
- llama_context & lctx,
- llama_batch inp_batch) {
-
- lctx.is_encoding = false;
-
- if (inp_batch.n_tokens == 0) {
- LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
- return -1;
- }
-
- // temporary allocate memory for the input batch if needed
- llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
- const llama_batch & batch = batch_allocr.batch;
- const uint32_t n_tokens_all = batch.n_tokens;
-
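+// validate the input batch and prepare lctx.sbatch for decoding:
+// checks the tokens, counts the requested outputs (returned via n_outputs)
+// and initializes lctx.sbatch from the batch
+//
+// return 0 on success
+// return negative int on error
+//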
+static int llama_prepare_sbatch(
+ llama_context & lctx,
+ const llama_batch & batch,
+ uint32_t & n_outputs) {
const auto & model = lctx.model;
- const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+ const uint32_t n_tokens_all = batch.n_tokens;
+ const int64_t n_embd = hparams.n_embd;
+
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+ GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
- if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+ if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
}
}
-
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
- if (lctx.t_compute_start_us == 0) {
- lctx.t_compute_start_us = ggml_time_us();
- }
lctx.n_queued_tokens += n_tokens_all;
-
- auto & kv_self = lctx.kv_self;
- llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
- const int64_t n_embd = hparams.n_embd;
- const int64_t n_vocab = vocab.n_tokens();
-
- uint32_t n_outputs = 0;
- uint32_t n_outputs_prev = 0;
-
- const auto n_ubatch = cparams.n_ubatch;
-
- // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
- const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
lctx.embd_seq.clear();
// count outputs
}
lctx.sbatch.from_batch(batch, n_embd,
- /* simple_split */ !kv_self.recurrent,
+ /* simple_split */ !lctx.kv_self.recurrent,
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer
return -2;
};
- while (lctx.sbatch.n_tokens > 0) {
- llama_ubatch ubatch;
- if (kv_self.recurrent) {
- if (embd_pooled) {
- // Pooled embeddings cannot be split across ubatches (yet)
- ubatch = lctx.sbatch.split_seq(n_ubatch);
- } else {
- // recurrent model architectures are easier to implement
- // with equal-length sequences
- ubatch = lctx.sbatch.split_equal(n_ubatch);
- }
+ return 0;
+}
+
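+// split the next ubatch off lctx.sbatch and prepare the KV cache for it:
+// sets lctx.n_outputs for this ubatch and, for causal attention,
+// finds a KV cache slot (saved via kv_slot_restorer so it can be reverted)
+//
+// return 0 on success
+// return 1 (warning) if no suitable KV cache slot was found
+//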
+static int llama_prepare_ubatch(
+ llama_context & lctx,
+ llama_kv_slot_restorer & kv_slot_restorer,
+ llama_ubatch & ubatch,
+ const uint32_t n_outputs,
+ const uint32_t n_tokens_all) {
+ GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+ auto & kv_self = lctx.kv_self;
+ const auto & cparams = lctx.cparams;
+ const auto & hparams = lctx.model.hparams;
+
+ // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+ const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+ if (lctx.kv_self.recurrent) {
+ if (embd_pooled) {
+ // Pooled embeddings cannot be split across ubatches (yet)
+ ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
} else {
- ubatch = lctx.sbatch.split_simple(n_ubatch);
+ // recurrent model architectures are easier to implement
+ // with equal-length sequences
+ ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
}
- const uint32_t n_tokens = ubatch.n_tokens;
+ } else {
+ ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+ }
- // count the outputs in this u_batch
- {
- int32_t n_outputs_new = 0;
+ // count the outputs in this ubatch
+ {
+ int32_t n_outputs_new = 0;
- if (n_outputs == n_tokens_all) {
- n_outputs_new = n_tokens;
- } else {
- GGML_ASSERT(ubatch.output);
- for (uint32_t i = 0; i < n_tokens; i++) {
- n_outputs_new += (int32_t) (ubatch.output[i] != 0);
- }
+ if (n_outputs == n_tokens_all) {
+ n_outputs_new = ubatch.n_tokens;
+ } else {
+ GGML_ASSERT(ubatch.output);
+ for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+ n_outputs_new += int32_t(ubatch.output[i] != 0);
}
+ }
+
+ // needs to happen before the graph is built
+ lctx.n_outputs = n_outputs_new;
+ }
+
+ // non-causal masks do not use the KV cache
+ if (hparams.causal_attn) {
+ llama_kv_cache_update(&lctx);
- // needs to happen before the graph is built
- lctx.n_outputs = n_outputs_new;
+ // if we have enough unused cells before the current head ->
+ // better to start searching from the beginning of the cache, hoping to fill it
+ if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+ kv_self.head = 0;
}
- int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
- ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+ const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+ if (!slot) {
+ return 1;
+ }
+ kv_slot_restorer.save(slot);
- GGML_ASSERT(n_threads > 0);
+ if (!kv_self.recurrent) {
+ // a heuristic, to avoid attending the full cache if it is not yet utilized
+ // after enough generations, the benefit from this heuristic disappears
+ // if we start defragmenting the cache, the benefit from this will be more important
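+ // e.g. if the highest used cell is 40 and the padding is 256, attention covers
+ // min(kv_self.size, 256) cells rather than the whole cache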
+ const uint32_t pad = llama_kv_cache_get_padding(cparams);
+ kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+ //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ }
+ }
- // non-causal masks do not use the KV cache
- if (hparams.causal_attn) {
- llama_kv_cache_update(&lctx);
+ return 0;
+}
- // if we have enough unused cells before the current head ->
- // better to start searching from the beginning of the cache, hoping to fill it
- if (kv_self.head > kv_self.used + 2*n_tokens) {
- kv_self.head = 0;
- }
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+// - lctx: llama context
+// - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+ llama_context & lctx,
+ llama_batch inp_batch) {
- const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
- if (!slot) {
- return 1;
- }
- kv_slot_restorer.save(slot);
+ lctx.is_encoding = false;
- if (!kv_self.recurrent) {
- // a heuristic, to avoid attending the full cache if it is not yet utilized
- // after enough generations, the benefit from this heuristic disappears
- // if we start defragmenting the cache, the benefit from this will be more important
- const uint32_t pad = llama_kv_cache_get_padding(cparams);
- kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
- //kv_self.n = llama_kv_cache_cell_max(kv_self);
+ if (inp_batch.n_tokens == 0) {
+ LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+ return -1;
+ }
+
+ // temporarily allocate memory for the input batch if needed
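+ // (if inp_batch.pos is not set, positions are generated starting at kv_self.max_pos() + 1)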
+ llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+ const llama_batch & batch = batch_allocr.batch;
+
+ const auto & model = lctx.model;
+ const auto & vocab = model.vocab;
+ const auto & hparams = model.hparams;
+ const auto & cparams = lctx.cparams;
+
+ if (lctx.t_compute_start_us == 0) {
+ lctx.t_compute_start_us = ggml_time_us();
+ }
+ auto & kv_self = lctx.kv_self;
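+ // records the KV cache slots claimed below so they can be reverted if decoding fails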
+ llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_vocab = vocab.n_tokens();
+
+ uint32_t n_outputs = 0;
+ uint32_t n_outputs_prev = 0;
+
+ {
+ const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+ if (ret != 0) {
+ return ret;
+ }
+ }
+
+ while (lctx.sbatch.n_tokens > 0) {
+ llama_ubatch ubatch;
+ {
+ const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+ if (ret != 0) {
+ return ret;
}
}
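+ // single-token ubatches (typical for autoregressive generation) use the regular
+ // thread count and threadpool, larger ubatches use the batch variants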
+ const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+ ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
+ GGML_ASSERT(n_threads > 0);
+
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched.get());
// update the kv ring buffer
{
- kv_self.head += n_tokens;
+ kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {