]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : pad small embedding batches (#13692)
authorGeorgi Gerganov <redacted>
Thu, 22 May 2025 13:33:39 +0000 (16:33 +0300)
committerGitHub <redacted>
Thu, 22 May 2025 13:33:39 +0000 (16:33 +0300)
ggml-ci

tools/server/server.cpp

index 7424da52385efbff98c4c65ca753209657b98060..1a08e30d28751df7f66d4916613e6fc3bb8d329e 100644 (file)
@@ -3341,6 +3341,37 @@ struct server_context {
             common_set_adapter_lora(ctx, slot_batched->lora);
         }
 
+        const bool do_encode = (params_base.embedding || params_base.reranking);
+
+        // pad the batch so that batch.n_tokens >= n_slots
+        // TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
+        if (do_encode) {
+            const int n_slots = slots.size();
+
+            if (batch.n_tokens < n_slots) {
+                std::set<llama_seq_id> seq_ids;
+                for (int j = 0; j < batch.n_tokens; ++j) {
+                    seq_ids.insert(batch.seq_id[j][0]);
+                }
+
+                // find unused sequence id
+                llama_seq_id seq_id = -1;
+                for (int i = 0; i < n_slots; ++i) {
+                    if (seq_ids.find(i) == seq_ids.end()) {
+                        seq_id = i;
+                    }
+                }
+
+                const int n_add = n_slots - batch.n_tokens;
+
+                SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
+
+                for (int j = 0; j < n_add; ++j) {
+                    common_batch_add(batch, 0, j, { seq_id }, false);
+                }
+            }
+        }
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -3357,7 +3388,7 @@ struct server_context {
 
             int ret = 0;
 
-            if (params_base.embedding || params_base.reranking) {
+            if (do_encode) {
                 ret = llama_encode(ctx, batch_view);
             } else {
                 ret = llama_decode(ctx, batch_view);