]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : use output_resolve_row() in get_logits_ith/get_embeddings_ith (#19663)
authorDaniel Bevenius <redacted>
Thu, 19 Feb 2026 08:48:08 +0000 (09:48 +0100)
committerGitHub <redacted>
Thu, 19 Feb 2026 08:48:08 +0000 (09:48 +0100)
This commit updates get_logits_ith(), and get_embeddings_ith() to use
output_resolve_row() to resolve the batch index to output row index.

The motivation for this is to remove some code duplication between these
functions.

src/llama-context.cpp

index 7f4b4a933eaf37d36974077ccccf3896aeaa1bc3..7cd0bfc0d2405e3176a6138d614f60ecce9f6a64 100644 (file)
@@ -712,8 +712,6 @@ int64_t llama_context::output_resolve_row(int32_t i) const {
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
-    int64_t j = -1;
-
     output_reorder();
 
     try {
@@ -721,26 +719,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error("no logits");
         }
 
-        // TODO: use output_resolve_row()
-        if (i < 0) {
-            j = n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
-            }
-        } else if ((size_t) i >= output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
-        } else {
-            j = output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
-        }
-
+        const int64_t j = output_resolve_row(i);
         return logits.data + j*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
@@ -763,8 +742,6 @@ llama_token * llama_context::get_sampled_tokens()  const{
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
-    int64_t j = -1;
-
     output_reorder();
 
     try {
@@ -772,26 +749,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error("no embeddings");
         }
 
-        // TODO: use output_resolve_row()
-        if (i < 0) {
-            j = n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
-            }
-        } else if ((size_t) i >= output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
-        } else {
-            j = output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
-        }
-
+        const int64_t j = output_resolve_row(i);
         const uint32_t n_embd_out = model.hparams.n_embd_out();
         return embd.data + j*n_embd_out;
     } catch (const std::exception & err) {