From: Aman Gupta Date: Mon, 14 Jul 2025 13:01:41 +0000 (+0800) Subject: llama-context: add ability to get logits (#14672) X-Git-Tag: upstream/0.0.6073~178 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9c9e4fc6354fc811efa06a8eb7a86d3315cec9c8;p=pkg%2Fggml%2Fsources%2Fllama.cpp llama-context: add ability to get logits (#14672) --- diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 06e93b19..7c07b047 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -731,7 +731,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd; + const int32_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) { @@ -791,10 +792,20 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + auto * t_logits = res->get_logits(); auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + // extract logits + if (logits && t_logits) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + } + // extract embeddings - if (t_embd) { + if (embd && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr);