]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
context : perform output reorder lazily upon access after sync (#14853)
authorGeorgi Gerganov <redacted>
Thu, 24 Jul 2025 13:31:48 +0000 (16:31 +0300)
committerGitHub <redacted>
Thu, 24 Jul 2025 13:31:48 +0000 (16:31 +0300)
* context : perform output reorder after lazily upon access after sync

ggml-ci

* cont : add TODO

include/llama.h
src/llama-context.cpp
src/llama-context.h

index 1c3a1cd1b4e7de305bcba2d9a7fab1bedd5ed699..6f454a508a06c80bb92fab68f97f06e6cb95ccd5 100644 (file)
@@ -956,6 +956,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
+    // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
     // Logits for the ith token. For positive indices, Equivalent to:
@@ -970,6 +971,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // shape: [n_outputs*n_embd]
     // Otherwise, returns NULL.
+    // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Get the embeddings for the ith token. For positive indices, Equivalent to:
index 6eb344736de6f398e768f2788d198df6eff24c98..a91d157e298032c914a7678dc06b67f4d7fc7b23 100644 (file)
@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // make the outputs have the same order they had in the user-provided batch
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
-            const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint64_t n_embd  = model.hparams.n_embd;
-
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }
                 std::swap(out_ids[i], out_ids[j_min]);
-                if (logits_size > 0) {
-                    for (uint32_t k = 0; k < n_vocab; k++) {
-                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                    }
-                }
-                if (embd_size > 0) {
-                    for (uint32_t k = 0; k < n_embd; k++) {
-                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                    }
-                }
+
+                // remember the swaps and apply them lazily upon logits/embeddings access
+                output_swaps.push_back({ i, j_min });
             }
 
             std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
index 1601ac682ea712733766d7930f707330c618f32e..fdbe61207e8ce49535ff1c2119cb7207d5bbcb93 100644 (file)
@@ -181,6 +181,8 @@ private:
     // Returns max number of outputs for which space was reserved.
     uint32_t output_reserve(int32_t n_outputs);
 
+    void output_reorder();
+
     //
     // graph
     //
@@ -250,6 +252,13 @@ private:
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
+    struct swap_info {
+        uint32_t i0;
+        uint32_t i1;
+    };
+
+    std::vector<swap_info> output_swaps;
+
     ggml_backend_sched_ptr sched;
 
     ggml_backend_t backend_cpu = nullptr;