]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama-context: add ability to get logits (#14672)
authorAman Gupta <redacted>
Mon, 14 Jul 2025 13:01:41 +0000 (21:01 +0800)
committerGitHub <redacted>
Mon, 14 Jul 2025 13:01:41 +0000 (21:01 +0800)
src/llama-context.cpp

index 06e93b19cbf4087b24284ea3a473f12441c31b6a..7c07b047b0dd9a6faaff2282f9c3acb1a0acd8b2 100644 (file)
@@ -731,7 +731,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     const auto & hparams = model.hparams;
 
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd  = hparams.n_embd;
+    const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
@@ -791,10 +792,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
+    auto * t_logits = res->get_logits();
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
+    // extract logits
+   if (logits && t_logits) {
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(logits != nullptr);
+
+        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+    }
+
     // extract embeddings
-    if (t_embd) {
+    if (embd && t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);