const auto & hparams = model.hparams;
- const int64_t n_embd = hparams.n_embd;
+ const int64_t n_embd = hparams.n_embd;
+ const int32_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
}
}
+ auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
+ // extract logits
+ if (logits && t_logits) {
+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+ GGML_ASSERT(backend_res != nullptr);
+ GGML_ASSERT(logits != nullptr);
+
+ ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+ }
+
// extract embeddings
- if (t_embd) {
+ if (embd && t_embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);